[Mlir-commits] [mlir] [mlir][NFC] update `Conversion` create APIs (6/n) (#149687) (PR #149888)
Maksim Levental
llvmlistbot at llvm.org
Mon Jul 21 13:04:42 PDT 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/149888
From d61265ad8917b52942cadcc4a6209f5534b41730 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Mon, 21 Jul 2025 15:26:00 -0400
Subject: [PATCH] [mlir][NFC] update `Conversion` create APIs (4/n) (#149687)
See https://github.com/llvm/llvm-project/pull/147168 for more info.
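For reference, the rewrite applied throughout this series is purely
mechanical: call sites of the form `rewriter.create<OpTy>(loc, ...)` become
the static form `OpTy::create(rewriter, loc, ...)`. A minimal, hypothetical
sketch of the before/after shape (the helper function, operands, and
location below are made up for illustration and are not part of this patch;
only the two call shapes come from the diff itself, and the usual MLIR
headers, e.g. mlir/Dialect/LLVMIR/LLVMDialect.h, are assumed):

    // Hypothetical helper, not part of this patch.
    static Value buildSum(OpBuilder &builder, Location loc, Value lhs,
                          Value rhs) {
      // Old form: templated member on the builder.
      //   return builder.create<LLVM::AddOp>(loc, lhs, rhs);
      // New form used in this patch: static create on the op class, with
      // the builder passed as the first argument.
      return LLVM::AddOp::create(builder, loc, lhs, rhs);
    }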
---
.../Conversion/IndexToLLVM/IndexToLLVM.cpp | 72 ++---
.../Conversion/IndexToSPIRV/IndexToSPIRV.cpp | 85 ++---
.../Conversion/LLVMCommon/MemRefBuilder.cpp | 118 +++----
mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 92 +++---
.../Conversion/LLVMCommon/PrintCallHelper.cpp | 12 +-
.../Conversion/LLVMCommon/StructBuilder.cpp | 4 +-
.../Conversion/LLVMCommon/TypeConverter.cpp | 16 +-
.../Conversion/LLVMCommon/VectorPattern.cpp | 8 +-
.../LinalgToStandard/LinalgToStandard.cpp | 8 +-
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 108 ++++---
.../Conversion/MathToFuncs/MathToFuncs.cpp | 265 +++++++--------
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 80 ++---
mlir/lib/Conversion/MathToLibm/MathToLibm.cpp | 21 +-
.../Conversion/MathToSPIRV/MathToSPIRV.cpp | 112 +++----
.../MemRefToEmitC/MemRefToEmitC.cpp | 14 +-
.../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 220 +++++++------
.../MemRefToSPIRV/MemRefToSPIRV.cpp | 37 +--
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 303 +++++++++---------
.../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 238 +++++++-------
.../Conversion/OpenACCToSCF/OpenACCToSCF.cpp | 4 +-
.../Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp | 4 +-
.../PDLToPDLInterp/PDLToPDLInterp.cpp | 217 +++++++------
22 files changed, 1046 insertions(+), 992 deletions(-)
diff --git a/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp b/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp
index 0473bb59fa6aa..99d2f6ca78c38 100644
--- a/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp
+++ b/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp
@@ -36,34 +36,34 @@ struct ConvertIndexCeilDivS : mlir::ConvertOpToLLVMPattern<CeilDivSOp> {
Location loc = op.getLoc();
Value n = adaptor.getLhs();
Value m = adaptor.getRhs();
- Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
- Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
- Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1);
+ Value zero = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 0);
+ Value posOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 1);
+ Value negOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), -1);
// Compute `x`.
Value mPos =
- rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, m, zero);
- Value x = rewriter.create<LLVM::SelectOp>(loc, mPos, negOne, posOne);
+ LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::sgt, m, zero);
+ Value x = LLVM::SelectOp::create(rewriter, loc, mPos, negOne, posOne);
// Compute the positive result.
- Value nPlusX = rewriter.create<LLVM::AddOp>(loc, n, x);
- Value nPlusXDivM = rewriter.create<LLVM::SDivOp>(loc, nPlusX, m);
- Value posRes = rewriter.create<LLVM::AddOp>(loc, nPlusXDivM, posOne);
+ Value nPlusX = LLVM::AddOp::create(rewriter, loc, n, x);
+ Value nPlusXDivM = LLVM::SDivOp::create(rewriter, loc, nPlusX, m);
+ Value posRes = LLVM::AddOp::create(rewriter, loc, nPlusXDivM, posOne);
// Compute the negative result.
- Value negN = rewriter.create<LLVM::SubOp>(loc, zero, n);
- Value negNDivM = rewriter.create<LLVM::SDivOp>(loc, negN, m);
- Value negRes = rewriter.create<LLVM::SubOp>(loc, zero, negNDivM);
+ Value negN = LLVM::SubOp::create(rewriter, loc, zero, n);
+ Value negNDivM = LLVM::SDivOp::create(rewriter, loc, negN, m);
+ Value negRes = LLVM::SubOp::create(rewriter, loc, zero, negNDivM);
// Pick the positive result if `n` and `m` have the same sign and `n` is
// non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
Value nPos =
- rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, n, zero);
- Value sameSign =
- rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, nPos, mPos);
+ LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::sgt, n, zero);
+ Value sameSign = LLVM::ICmpOp::create(rewriter, loc,
+ LLVM::ICmpPredicate::eq, nPos, mPos);
Value nNonZero =
- rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero);
- Value cmp = rewriter.create<LLVM::AndOp>(loc, sameSign, nNonZero);
+ LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::ne, n, zero);
+ Value cmp = LLVM::AndOp::create(rewriter, loc, sameSign, nNonZero);
rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, posRes, negRes);
return success();
}
@@ -83,17 +83,17 @@ struct ConvertIndexCeilDivU : mlir::ConvertOpToLLVMPattern<CeilDivUOp> {
Location loc = op.getLoc();
Value n = adaptor.getLhs();
Value m = adaptor.getRhs();
- Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
- Value one = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
+ Value zero = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 0);
+ Value one = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 1);
// Compute the non-zero result.
- Value minusOne = rewriter.create<LLVM::SubOp>(loc, n, one);
- Value quotient = rewriter.create<LLVM::UDivOp>(loc, minusOne, m);
- Value plusOne = rewriter.create<LLVM::AddOp>(loc, quotient, one);
+ Value minusOne = LLVM::SubOp::create(rewriter, loc, n, one);
+ Value quotient = LLVM::UDivOp::create(rewriter, loc, minusOne, m);
+ Value plusOne = LLVM::AddOp::create(rewriter, loc, quotient, one);
// Pick the result.
Value cmp =
- rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, n, zero);
+ LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq, n, zero);
rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, zero, plusOne);
return success();
}
@@ -114,32 +114,32 @@ struct ConvertIndexFloorDivS : mlir::ConvertOpToLLVMPattern<FloorDivSOp> {
Location loc = op.getLoc();
Value n = adaptor.getLhs();
Value m = adaptor.getRhs();
- Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
- Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
- Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1);
+ Value zero = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 0);
+ Value posOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 1);
+ Value negOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), -1);
// Compute `x`.
Value mNeg =
- rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, m, zero);
- Value x = rewriter.create<LLVM::SelectOp>(loc, mNeg, posOne, negOne);
+ LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::slt, m, zero);
+ Value x = LLVM::SelectOp::create(rewriter, loc, mNeg, posOne, negOne);
// Compute the negative result.
- Value xMinusN = rewriter.create<LLVM::SubOp>(loc, x, n);
- Value xMinusNDivM = rewriter.create<LLVM::SDivOp>(loc, xMinusN, m);
- Value negRes = rewriter.create<LLVM::SubOp>(loc, negOne, xMinusNDivM);
+ Value xMinusN = LLVM::SubOp::create(rewriter, loc, x, n);
+ Value xMinusNDivM = LLVM::SDivOp::create(rewriter, loc, xMinusN, m);
+ Value negRes = LLVM::SubOp::create(rewriter, loc, negOne, xMinusNDivM);
// Compute the positive result.
- Value posRes = rewriter.create<LLVM::SDivOp>(loc, n, m);
+ Value posRes = LLVM::SDivOp::create(rewriter, loc, n, m);
// Pick the negative result if `n` and `m` have different signs and `n` is
// non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
Value nNeg =
- rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, n, zero);
- Value diffSign =
- rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, nNeg, mNeg);
+ LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::slt, n, zero);
+ Value diffSign = LLVM::ICmpOp::create(rewriter, loc,
+ LLVM::ICmpPredicate::ne, nNeg, mNeg);
Value nNonZero =
- rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero);
- Value cmp = rewriter.create<LLVM::AndOp>(loc, diffSign, nNonZero);
+ LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::ne, n, zero);
+ Value cmp = LLVM::AndOp::create(rewriter, loc, diffSign, nNonZero);
rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, negRes, posRes);
return success();
}
diff --git a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
index 4821962f989e6..36cfe9dd6e2db 100644
--- a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
+++ b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
@@ -111,33 +111,33 @@ struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
Value m = adaptor.getRhs();
// Define the constants
- Value zero = rewriter.create<spirv::ConstantOp>(
- loc, n_type, IntegerAttr::get(n_type, 0));
- Value posOne = rewriter.create<spirv::ConstantOp>(
- loc, n_type, IntegerAttr::get(n_type, 1));
- Value negOne = rewriter.create<spirv::ConstantOp>(
- loc, n_type, IntegerAttr::get(n_type, -1));
+ Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
+ IntegerAttr::get(n_type, 0));
+ Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type,
+ IntegerAttr::get(n_type, 1));
+ Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type,
+ IntegerAttr::get(n_type, -1));
// Compute `x`.
- Value mPos = rewriter.create<spirv::SGreaterThanOp>(loc, m, zero);
- Value x = rewriter.create<spirv::SelectOp>(loc, mPos, negOne, posOne);
+ Value mPos = spirv::SGreaterThanOp::create(rewriter, loc, m, zero);
+ Value x = spirv::SelectOp::create(rewriter, loc, mPos, negOne, posOne);
// Compute the positive result.
- Value nPlusX = rewriter.create<spirv::IAddOp>(loc, n, x);
- Value nPlusXDivM = rewriter.create<spirv::SDivOp>(loc, nPlusX, m);
- Value posRes = rewriter.create<spirv::IAddOp>(loc, nPlusXDivM, posOne);
+ Value nPlusX = spirv::IAddOp::create(rewriter, loc, n, x);
+ Value nPlusXDivM = spirv::SDivOp::create(rewriter, loc, nPlusX, m);
+ Value posRes = spirv::IAddOp::create(rewriter, loc, nPlusXDivM, posOne);
// Compute the negative result.
- Value negN = rewriter.create<spirv::ISubOp>(loc, zero, n);
- Value negNDivM = rewriter.create<spirv::SDivOp>(loc, negN, m);
- Value negRes = rewriter.create<spirv::ISubOp>(loc, zero, negNDivM);
+ Value negN = spirv::ISubOp::create(rewriter, loc, zero, n);
+ Value negNDivM = spirv::SDivOp::create(rewriter, loc, negN, m);
+ Value negRes = spirv::ISubOp::create(rewriter, loc, zero, negNDivM);
// Pick the positive result if `n` and `m` have the same sign and `n` is
// non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
- Value nPos = rewriter.create<spirv::SGreaterThanOp>(loc, n, zero);
- Value sameSign = rewriter.create<spirv::LogicalEqualOp>(loc, nPos, mPos);
- Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
- Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, sameSign, nNonZero);
+ Value nPos = spirv::SGreaterThanOp::create(rewriter, loc, n, zero);
+ Value sameSign = spirv::LogicalEqualOp::create(rewriter, loc, nPos, mPos);
+ Value nNonZero = spirv::INotEqualOp::create(rewriter, loc, n, zero);
+ Value cmp = spirv::LogicalAndOp::create(rewriter, loc, sameSign, nNonZero);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
return success();
}
@@ -161,18 +161,18 @@ struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
Value m = adaptor.getRhs();
// Define the constants
- Value zero = rewriter.create<spirv::ConstantOp>(
- loc, n_type, IntegerAttr::get(n_type, 0));
- Value one = rewriter.create<spirv::ConstantOp>(loc, n_type,
- IntegerAttr::get(n_type, 1));
+ Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
+ IntegerAttr::get(n_type, 0));
+ Value one = spirv::ConstantOp::create(rewriter, loc, n_type,
+ IntegerAttr::get(n_type, 1));
// Compute the non-zero result.
- Value minusOne = rewriter.create<spirv::ISubOp>(loc, n, one);
- Value quotient = rewriter.create<spirv::UDivOp>(loc, minusOne, m);
- Value plusOne = rewriter.create<spirv::IAddOp>(loc, quotient, one);
+ Value minusOne = spirv::ISubOp::create(rewriter, loc, n, one);
+ Value quotient = spirv::UDivOp::create(rewriter, loc, minusOne, m);
+ Value plusOne = spirv::IAddOp::create(rewriter, loc, quotient, one);
// Pick the result
- Value cmp = rewriter.create<spirv::IEqualOp>(loc, n, zero);
+ Value cmp = spirv::IEqualOp::create(rewriter, loc, n, zero);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, zero, plusOne);
return success();
}
@@ -197,32 +197,33 @@ struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
Value m = adaptor.getRhs();
// Define the constants
- Value zero = rewriter.create<spirv::ConstantOp>(
- loc, n_type, IntegerAttr::get(n_type, 0));
- Value posOne = rewriter.create<spirv::ConstantOp>(
- loc, n_type, IntegerAttr::get(n_type, 1));
- Value negOne = rewriter.create<spirv::ConstantOp>(
- loc, n_type, IntegerAttr::get(n_type, -1));
+ Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
+ IntegerAttr::get(n_type, 0));
+ Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type,
+ IntegerAttr::get(n_type, 1));
+ Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type,
+ IntegerAttr::get(n_type, -1));
// Compute `x`.
- Value mNeg = rewriter.create<spirv::SLessThanOp>(loc, m, zero);
- Value x = rewriter.create<spirv::SelectOp>(loc, mNeg, posOne, negOne);
+ Value mNeg = spirv::SLessThanOp::create(rewriter, loc, m, zero);
+ Value x = spirv::SelectOp::create(rewriter, loc, mNeg, posOne, negOne);
// Compute the negative result
- Value xMinusN = rewriter.create<spirv::ISubOp>(loc, x, n);
- Value xMinusNDivM = rewriter.create<spirv::SDivOp>(loc, xMinusN, m);
- Value negRes = rewriter.create<spirv::ISubOp>(loc, negOne, xMinusNDivM);
+ Value xMinusN = spirv::ISubOp::create(rewriter, loc, x, n);
+ Value xMinusNDivM = spirv::SDivOp::create(rewriter, loc, xMinusN, m);
+ Value negRes = spirv::ISubOp::create(rewriter, loc, negOne, xMinusNDivM);
// Compute the positive result.
- Value posRes = rewriter.create<spirv::SDivOp>(loc, n, m);
+ Value posRes = spirv::SDivOp::create(rewriter, loc, n, m);
// Pick the negative result if `n` and `m` have different signs and `n` is
// non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
- Value nNeg = rewriter.create<spirv::SLessThanOp>(loc, n, zero);
- Value diffSign = rewriter.create<spirv::LogicalNotEqualOp>(loc, nNeg, mNeg);
- Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
- Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, diffSign, nNonZero);
+ Value nNeg = spirv::SLessThanOp::create(rewriter, loc, n, zero);
+ Value diffSign =
+ spirv::LogicalNotEqualOp::create(rewriter, loc, nNeg, mNeg);
+ Value nNonZero = spirv::INotEqualOp::create(rewriter, loc, n, zero);
+ Value cmp = spirv::LogicalAndOp::create(rewriter, loc, diffSign, nNonZero);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
return success();
}
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index e34d5f74d232f..fce7a3f324b86 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -32,7 +32,7 @@ MemRefDescriptor::MemRefDescriptor(Value descriptor)
MemRefDescriptor MemRefDescriptor::poison(OpBuilder &builder, Location loc,
Type descriptorType) {
- Value descriptor = builder.create<LLVM::PoisonOp>(loc, descriptorType);
+ Value descriptor = LLVM::PoisonOp::create(builder, loc, descriptorType);
return MemRefDescriptor(descriptor);
}
@@ -99,21 +99,21 @@ void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
// integer attribute.
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
Type resultType, int64_t value) {
- return builder.create<LLVM::ConstantOp>(loc, resultType,
- builder.getIndexAttr(value));
+ return LLVM::ConstantOp::create(builder, loc, resultType,
+ builder.getIndexAttr(value));
}
/// Builds IR extracting the offset from the descriptor.
Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
- return builder.create<LLVM::ExtractValueOp>(loc, value,
- kOffsetPosInMemRefDescriptor);
+ return LLVM::ExtractValueOp::create(builder, loc, value,
+ kOffsetPosInMemRefDescriptor);
}
/// Builds IR inserting the offset into the descriptor.
void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
Value offset) {
- value = builder.create<LLVM::InsertValueOp>(loc, value, offset,
- kOffsetPosInMemRefDescriptor);
+ value = LLVM::InsertValueOp::create(builder, loc, value, offset,
+ kOffsetPosInMemRefDescriptor);
}
/// Builds IR inserting the offset into the descriptor.
@@ -125,8 +125,9 @@ void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc,
/// Builds IR extracting the pos-th size from the descriptor.
Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
- return builder.create<LLVM::ExtractValueOp>(
- loc, value, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos}));
+ return LLVM::ExtractValueOp::create(
+ builder, loc, value,
+ ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos}));
}
Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
@@ -137,23 +138,25 @@ Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
// Copy size values to stack-allocated memory.
auto one = createIndexAttrConstant(builder, loc, indexType, 1);
- auto sizes = builder.create<LLVM::ExtractValueOp>(
- loc, value, llvm::ArrayRef<int64_t>({kSizePosInMemRefDescriptor}));
- auto sizesPtr = builder.create<LLVM::AllocaOp>(loc, ptrTy, arrayTy, one,
- /*alignment=*/0);
- builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr);
+ auto sizes = LLVM::ExtractValueOp::create(
+ builder, loc, value,
+ llvm::ArrayRef<int64_t>({kSizePosInMemRefDescriptor}));
+ auto sizesPtr = LLVM::AllocaOp::create(builder, loc, ptrTy, arrayTy, one,
+ /*alignment=*/0);
+ LLVM::StoreOp::create(builder, loc, sizes, sizesPtr);
// Load and return the size value of interest.
- auto resultPtr = builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, sizesPtr,
- ArrayRef<LLVM::GEPArg>{0, pos});
- return builder.create<LLVM::LoadOp>(loc, indexType, resultPtr);
+ auto resultPtr = LLVM::GEPOp::create(builder, loc, ptrTy, arrayTy, sizesPtr,
+ ArrayRef<LLVM::GEPArg>{0, pos});
+ return LLVM::LoadOp::create(builder, loc, indexType, resultPtr);
}
/// Builds IR inserting the pos-th size into the descriptor
void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
Value size) {
- value = builder.create<LLVM::InsertValueOp>(
- loc, value, size, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos}));
+ value = LLVM::InsertValueOp::create(
+ builder, loc, value, size,
+ ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos}));
}
void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
@@ -164,15 +167,16 @@ void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
/// Builds IR extracting the pos-th stride from the descriptor.
Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) {
- return builder.create<LLVM::ExtractValueOp>(
- loc, value, ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos}));
+ return LLVM::ExtractValueOp::create(
+ builder, loc, value,
+ ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos}));
}
/// Builds IR inserting the pos-th stride into the descriptor
void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
Value stride) {
- value = builder.create<LLVM::InsertValueOp>(
- loc, value, stride,
+ value = LLVM::InsertValueOp::create(
+ builder, loc, value, stride,
ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos}));
}
@@ -207,8 +211,8 @@ Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc,
? offset(builder, loc)
: createIndexAttrConstant(builder, loc, indexType, offsetCst);
Type elementType = converter.convertType(type.getElementType());
- ptr = builder.create<LLVM::GEPOp>(loc, ptr.getType(), elementType, ptr,
- offsetVal);
+ ptr = LLVM::GEPOp::create(builder, loc, ptr.getType(), elementType, ptr,
+ offsetVal);
return ptr;
}
@@ -303,7 +307,7 @@ UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor)
UnrankedMemRefDescriptor UnrankedMemRefDescriptor::poison(OpBuilder &builder,
Location loc,
Type descriptorType) {
- Value descriptor = builder.create<LLVM::PoisonOp>(loc, descriptorType);
+ Value descriptor = LLVM::PoisonOp::create(builder, loc, descriptorType);
return UnrankedMemRefDescriptor(descriptor);
}
Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) const {
@@ -380,19 +384,19 @@ void UnrankedMemRefDescriptor::computeSizes(
builder, loc, indexType,
llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
Value doublePointerSize =
- builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize);
+ LLVM::MulOp::create(builder, loc, indexType, two, pointerSize);
// (1 + 2 * rank) * sizeof(index)
Value rank = desc.rank(builder, loc);
- Value doubleRank = builder.create<LLVM::MulOp>(loc, indexType, two, rank);
+ Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank);
Value doubleRankIncremented =
- builder.create<LLVM::AddOp>(loc, indexType, doubleRank, one);
- Value rankIndexSize = builder.create<LLVM::MulOp>(
- loc, indexType, doubleRankIncremented, indexSize);
+ LLVM::AddOp::create(builder, loc, indexType, doubleRank, one);
+ Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType,
+ doubleRankIncremented, indexSize);
// Total allocation size.
- Value allocationSize = builder.create<LLVM::AddOp>(
- loc, indexType, doublePointerSize, rankIndexSize);
+ Value allocationSize = LLVM::AddOp::create(
+ builder, loc, indexType, doublePointerSize, rankIndexSize);
sizes.push_back(allocationSize);
}
}
@@ -400,13 +404,13 @@ void UnrankedMemRefDescriptor::computeSizes(
Value UnrankedMemRefDescriptor::allocatedPtr(
OpBuilder &builder, Location loc, Value memRefDescPtr,
LLVM::LLVMPointerType elemPtrType) {
- return builder.create<LLVM::LoadOp>(loc, elemPtrType, memRefDescPtr);
+ return LLVM::LoadOp::create(builder, loc, elemPtrType, memRefDescPtr);
}
void UnrankedMemRefDescriptor::setAllocatedPtr(
OpBuilder &builder, Location loc, Value memRefDescPtr,
LLVM::LLVMPointerType elemPtrType, Value allocatedPtr) {
- builder.create<LLVM::StoreOp>(loc, allocatedPtr, memRefDescPtr);
+ LLVM::StoreOp::create(builder, loc, allocatedPtr, memRefDescPtr);
}
static std::pair<Value, Type>
@@ -423,9 +427,9 @@ Value UnrankedMemRefDescriptor::alignedPtr(
castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType);
Value alignedGep =
- builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType,
- elementPtrPtr, ArrayRef<LLVM::GEPArg>{1});
- return builder.create<LLVM::LoadOp>(loc, elemPtrType, alignedGep);
+ LLVM::GEPOp::create(builder, loc, elemPtrPtrType, elemPtrType,
+ elementPtrPtr, ArrayRef<LLVM::GEPArg>{1});
+ return LLVM::LoadOp::create(builder, loc, elemPtrType, alignedGep);
}
void UnrankedMemRefDescriptor::setAlignedPtr(
@@ -435,9 +439,9 @@ void UnrankedMemRefDescriptor::setAlignedPtr(
castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType);
Value alignedGep =
- builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType,
- elementPtrPtr, ArrayRef<LLVM::GEPArg>{1});
- builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep);
+ LLVM::GEPOp::create(builder, loc, elemPtrPtrType, elemPtrType,
+ elementPtrPtr, ArrayRef<LLVM::GEPArg>{1});
+ LLVM::StoreOp::create(builder, loc, alignedPtr, alignedGep);
}
Value UnrankedMemRefDescriptor::offsetBasePtr(
@@ -446,8 +450,8 @@ Value UnrankedMemRefDescriptor::offsetBasePtr(
auto [elementPtrPtr, elemPtrPtrType] =
castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType);
- return builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType,
- elementPtrPtr, ArrayRef<LLVM::GEPArg>{2});
+ return LLVM::GEPOp::create(builder, loc, elemPtrPtrType, elemPtrType,
+ elementPtrPtr, ArrayRef<LLVM::GEPArg>{2});
}
Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
@@ -456,8 +460,8 @@ Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
LLVM::LLVMPointerType elemPtrType) {
Value offsetPtr =
offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType);
- return builder.create<LLVM::LoadOp>(loc, typeConverter.getIndexType(),
- offsetPtr);
+ return LLVM::LoadOp::create(builder, loc, typeConverter.getIndexType(),
+ offsetPtr);
}
void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
@@ -467,7 +471,7 @@ void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
Value offset) {
Value offsetPtr =
offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType);
- builder.create<LLVM::StoreOp>(loc, offset, offsetPtr);
+ LLVM::StoreOp::create(builder, loc, offset, offsetPtr);
}
Value UnrankedMemRefDescriptor::sizeBasePtr(
@@ -477,8 +481,8 @@ Value UnrankedMemRefDescriptor::sizeBasePtr(
Type structTy = LLVM::LLVMStructType::getLiteral(
indexTy.getContext(), {elemPtrType, elemPtrType, indexTy, indexTy});
auto resultType = LLVM::LLVMPointerType::get(builder.getContext());
- return builder.create<LLVM::GEPOp>(loc, resultType, structTy, memRefDescPtr,
- ArrayRef<LLVM::GEPArg>{0, 3});
+ return LLVM::GEPOp::create(builder, loc, resultType, structTy, memRefDescPtr,
+ ArrayRef<LLVM::GEPArg>{0, 3});
}
Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
@@ -489,8 +493,8 @@ Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
Value sizeStoreGep =
- builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, sizeBasePtr, index);
- return builder.create<LLVM::LoadOp>(loc, indexTy, sizeStoreGep);
+ LLVM::GEPOp::create(builder, loc, ptrType, indexTy, sizeBasePtr, index);
+ return LLVM::LoadOp::create(builder, loc, indexTy, sizeStoreGep);
}
void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
@@ -501,8 +505,8 @@ void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
Value sizeStoreGep =
- builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, sizeBasePtr, index);
- builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep);
+ LLVM::GEPOp::create(builder, loc, ptrType, indexTy, sizeBasePtr, index);
+ LLVM::StoreOp::create(builder, loc, size, sizeStoreGep);
}
Value UnrankedMemRefDescriptor::strideBasePtr(
@@ -511,7 +515,7 @@ Value UnrankedMemRefDescriptor::strideBasePtr(
Type indexTy = typeConverter.getIndexType();
auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
- return builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, sizeBasePtr, rank);
+ return LLVM::GEPOp::create(builder, loc, ptrType, indexTy, sizeBasePtr, rank);
}
Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
@@ -522,8 +526,8 @@ Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
Value strideStoreGep =
- builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, strideBasePtr, index);
- return builder.create<LLVM::LoadOp>(loc, indexTy, strideStoreGep);
+ LLVM::GEPOp::create(builder, loc, ptrType, indexTy, strideBasePtr, index);
+ return LLVM::LoadOp::create(builder, loc, indexTy, strideStoreGep);
}
void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
@@ -534,6 +538,6 @@ void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
Value strideStoreGep =
- builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, strideBasePtr, index);
- builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep);
+ LLVM::GEPOp::create(builder, loc, ptrType, indexTy, strideBasePtr, index);
+ LLVM::StoreOp::create(builder, loc, stride, strideStoreGep);
}
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index c5f72f7e10b8c..ecd5b6367fba4 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -57,8 +57,8 @@ Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
Location loc,
Type resultType,
int64_t value) {
- return builder.create<LLVM::ConstantOp>(loc, resultType,
- builder.getIndexAttr(value));
+ return LLVM::ConstantOp::create(builder, loc, resultType,
+ builder.getIndexAttr(value));
}
Value ConvertToLLVMPattern::getStridedElementPtr(
@@ -123,7 +123,7 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
runningStride = sizes[i];
else if (stride == ShapedType::kDynamic)
runningStride =
- rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
+ LLVM::MulOp::create(rewriter, loc, runningStride, sizes[i]);
else
runningStride = createIndexAttrConstant(rewriter, loc, indexType, stride);
}
@@ -131,10 +131,10 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
// Buffer size in bytes.
Type elementType = typeConverter->convertType(memRefType.getElementType());
auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
- Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
- Value gepPtr = rewriter.create<LLVM::GEPOp>(
- loc, elementPtrType, elementType, nullPtr, runningStride);
- size = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
+ Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
+ Value gepPtr = LLVM::GEPOp::create(rewriter, loc, elementPtrType,
+ elementType, nullPtr, runningStride);
+ size = LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr);
} else {
size = runningStride;
}
@@ -149,10 +149,10 @@ Value ConvertToLLVMPattern::getSizeInBytes(
// which is a common pattern of getting the size of a type in bytes.
Type llvmType = typeConverter->convertType(type);
auto convertedPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
- auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, convertedPtrType);
- auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, llvmType,
- nullPtr, ArrayRef<LLVM::GEPArg>{1});
- return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
+ auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, convertedPtrType);
+ auto gep = LLVM::GEPOp::create(rewriter, loc, convertedPtrType, llvmType,
+ nullPtr, ArrayRef<LLVM::GEPArg>{1});
+ return LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gep);
}
Value ConvertToLLVMPattern::getNumElements(
@@ -175,7 +175,7 @@ Value ConvertToLLVMPattern::getNumElements(
staticSize == ShapedType::kDynamic
? dynamicSizes[dynamicIndex++]
: createIndexAttrConstant(rewriter, loc, indexType, staticSize);
- numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
+ numElements = LLVM::MulOp::create(rewriter, loc, numElements, size);
} else {
numElements =
staticSize == ShapedType::kDynamic
@@ -276,14 +276,14 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
? builder
.create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
.getResult()
- : builder.create<LLVM::AllocaOp>(loc, getPtrType(),
- IntegerType::get(getContext(), 8),
- allocationSize,
- /*alignment=*/0);
+ : LLVM::AllocaOp::create(builder, loc, getPtrType(),
+ IntegerType::get(getContext(), 8),
+ allocationSize,
+ /*alignment=*/0);
Value source = desc.memRefDescPtr(builder, loc);
- builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false);
+ LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
if (!toDynamic)
- builder.create<LLVM::CallOp>(loc, freeFunc.value(), source);
+ LLVM::CallOp::create(builder, loc, freeFunc.value(), source);
// Create a new descriptor. The same descriptor can be returned multiple
// times, attempting to modify its pointer can lead to memory leaks
@@ -349,8 +349,8 @@ LogicalResult LLVM::detail::oneToOneRewrite(
SmallVector<Value, 4> results;
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
- results.push_back(rewriter.create<LLVM::ExtractValueOp>(
- op->getLoc(), newOp->getResult(0), i));
+ results.push_back(LLVM::ExtractValueOp::create(rewriter, op->getLoc(),
+ newOp->getResult(0), i));
}
rewriter.replaceOp(op, results);
return success();
@@ -371,8 +371,8 @@ LogicalResult LLVM::detail::intrinsicRewrite(
if (numResults != 0)
resType = typeConverter.packOperationResults(op->getResultTypes());
- auto callIntrOp = rewriter.create<LLVM::CallIntrinsicOp>(
- loc, resType, rewriter.getStringAttr(intrinsic), operands);
+ auto callIntrOp = LLVM::CallIntrinsicOp::create(
+ rewriter, loc, resType, rewriter.getStringAttr(intrinsic), operands);
// Propagate attributes.
callIntrOp->setAttrs(op->getAttrDictionary());
@@ -388,7 +388,7 @@ LogicalResult LLVM::detail::intrinsicRewrite(
results.reserve(numResults);
Value intrRes = callIntrOp.getResults();
for (unsigned i = 0; i < numResults; ++i)
- results.push_back(rewriter.create<LLVM::ExtractValueOp>(loc, intrRes, i));
+ results.push_back(LLVM::ExtractValueOp::create(rewriter, loc, intrRes, i));
rewriter.replaceOp(op, results);
return success();
@@ -406,7 +406,7 @@ static unsigned getBitWidth(Type type) {
static Value createI32Constant(OpBuilder &builder, Location loc,
int32_t value) {
Type i32 = builder.getI32Type();
- return builder.create<LLVM::ConstantOp>(loc, i32, value);
+ return LLVM::ConstantOp::create(builder, loc, i32, value);
}
SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
@@ -418,17 +418,17 @@ SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
unsigned srcBitWidth = getBitWidth(srcType);
unsigned dstBitWidth = getBitWidth(dstType);
if (srcBitWidth == dstBitWidth) {
- Value cast = builder.create<LLVM::BitcastOp>(loc, dstType, src);
+ Value cast = LLVM::BitcastOp::create(builder, loc, dstType, src);
return {cast};
}
if (dstBitWidth > srcBitWidth) {
auto smallerInt = builder.getIntegerType(srcBitWidth);
if (srcType != smallerInt)
- src = builder.create<LLVM::BitcastOp>(loc, smallerInt, src);
+ src = LLVM::BitcastOp::create(builder, loc, smallerInt, src);
auto largerInt = builder.getIntegerType(dstBitWidth);
- Value res = builder.create<LLVM::ZExtOp>(loc, largerInt, src);
+ Value res = LLVM::ZExtOp::create(builder, loc, largerInt, src);
return {res};
}
assert(srcBitWidth % dstBitWidth == 0 &&
@@ -436,12 +436,12 @@ SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
int64_t numElements = srcBitWidth / dstBitWidth;
auto vecType = VectorType::get(numElements, dstType);
- src = builder.create<LLVM::BitcastOp>(loc, vecType, src);
+ src = LLVM::BitcastOp::create(builder, loc, vecType, src);
SmallVector<Value> res;
for (auto i : llvm::seq(numElements)) {
Value idx = createI32Constant(builder, loc, i);
- Value elem = builder.create<LLVM::ExtractElementOp>(loc, src, idx);
+ Value elem = LLVM::ExtractElementOp::create(builder, loc, src, idx);
res.emplace_back(elem);
}
@@ -461,28 +461,28 @@ Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
if (dstBitWidth < srcBitWidth) {
auto largerInt = builder.getIntegerType(srcBitWidth);
if (res.getType() != largerInt)
- res = builder.create<LLVM::BitcastOp>(loc, largerInt, res);
+ res = LLVM::BitcastOp::create(builder, loc, largerInt, res);
auto smallerInt = builder.getIntegerType(dstBitWidth);
- res = builder.create<LLVM::TruncOp>(loc, smallerInt, res);
+ res = LLVM::TruncOp::create(builder, loc, smallerInt, res);
}
if (res.getType() != dstType)
- res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
+ res = LLVM::BitcastOp::create(builder, loc, dstType, res);
return res;
}
int64_t numElements = src.size();
auto srcType = VectorType::get(numElements, src.front().getType());
- Value res = builder.create<LLVM::PoisonOp>(loc, srcType);
+ Value res = LLVM::PoisonOp::create(builder, loc, srcType);
for (auto &&[i, elem] : llvm::enumerate(src)) {
Value idx = createI32Constant(builder, loc, i);
- res = builder.create<LLVM::InsertElementOp>(loc, srcType, res, elem, idx);
+ res = LLVM::InsertElementOp::create(builder, loc, srcType, res, elem, idx);
}
if (res.getType() != dstType)
- res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
+ res = LLVM::BitcastOp::create(builder, loc, dstType, res);
return res;
}
@@ -518,20 +518,20 @@ Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc,
Value stride =
ShapedType::isDynamic(strides[i])
? memRefDescriptor.stride(builder, loc, i)
- : builder.create<LLVM::ConstantOp>(
- loc, indexType, builder.getIndexAttr(strides[i]));
- increment =
- builder.create<LLVM::MulOp>(loc, increment, stride, intOverflowFlags);
+ : LLVM::ConstantOp::create(builder, loc, indexType,
+ builder.getIndexAttr(strides[i]));
+ increment = LLVM::MulOp::create(builder, loc, increment, stride,
+ intOverflowFlags);
}
- index = index ? builder.create<LLVM::AddOp>(loc, index, increment,
- intOverflowFlags)
+ index = index ? LLVM::AddOp::create(builder, loc, index, increment,
+ intOverflowFlags)
: increment;
}
Type elementPtrType = memRefDescriptor.getElementPtrType();
- return index ? builder.create<LLVM::GEPOp>(
- loc, elementPtrType,
- converter.convertType(type.getElementType()), base, index,
- noWrapFlags)
- : base;
+ return index
+ ? LLVM::GEPOp::create(builder, loc, elementPtrType,
+ converter.convertType(type.getElementType()),
+ base, index, noWrapFlags)
+ : base;
}
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
index 49c73fbc9dd79..d95aeba8a4488 100644
--- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -66,23 +66,23 @@ LogicalResult mlir::LLVM::createPrintStrCall(
DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
auto arrayTy =
LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
- auto globalOp = builder.create<LLVM::GlobalOp>(
- loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private,
+ auto globalOp = LLVM::GlobalOp::create(
+ builder, loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private,
ensureSymbolNameIsUnique(moduleOp, symbolName, symbolTables), dataAttr);
auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
// Emit call to `printStr` in runtime library.
builder.restoreInsertionPoint(ip);
auto msgAddr =
- builder.create<LLVM::AddressOfOp>(loc, ptrTy, globalOp.getName());
+ LLVM::AddressOfOp::create(builder, loc, ptrTy, globalOp.getName());
SmallVector<LLVM::GEPArg> indices(1, 0);
Value gep =
- builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, msgAddr, indices);
+ LLVM::GEPOp::create(builder, loc, ptrTy, arrayTy, msgAddr, indices);
FailureOr<LLVM::LLVMFuncOp> printer =
LLVM::lookupOrCreatePrintStringFn(builder, moduleOp, runtimeFunctionName);
if (failed(printer))
return failure();
- builder.create<LLVM::CallOp>(loc, TypeRange(),
- SymbolRefAttr::get(printer.value()), gep);
+ LLVM::CallOp::create(builder, loc, TypeRange(),
+ SymbolRefAttr::get(printer.value()), gep);
return success();
}
diff --git a/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp
index 1cd0bd85f9894..13ed4628c3c9e 100644
--- a/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp
@@ -24,10 +24,10 @@ StructBuilder::StructBuilder(Value v) : value(v), structType(v.getType()) {
Value StructBuilder::extractPtr(OpBuilder &builder, Location loc,
unsigned pos) const {
- return builder.create<LLVM::ExtractValueOp>(loc, value, pos);
+ return LLVM::ExtractValueOp::create(builder, loc, value, pos);
}
void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos,
Value ptr) {
- value = builder.create<LLVM::InsertValueOp>(loc, value, ptr, pos);
+ value = LLVM::InsertValueOp::create(builder, loc, value, ptr, pos);
}
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 7312594c761f7..1a9bf569086da 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -91,7 +91,7 @@ static Value unrankedMemRefMaterialization(OpBuilder &builder,
packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter);
if (!packed)
return Value();
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
+ return UnrealizedConversionCastOp::create(builder, loc, resultType, packed)
.getResult(0);
}
@@ -107,7 +107,7 @@ static Value rankedMemRefMaterialization(OpBuilder &builder,
packRankedMemRefDesc(builder, resultType, inputs, loc, converter);
if (!packed)
return Value();
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
+ return UnrealizedConversionCastOp::create(builder, loc, resultType, packed)
.getResult(0);
}
@@ -224,12 +224,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
// non-LLVM types persist after an LLVM conversion.
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs, Location loc) {
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
.getResult(0);
});
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs, Location loc) {
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
.getResult(0);
});
@@ -731,12 +731,12 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
// Alloca with proper alignment. We do not expect optimizations of this
// alloca op and so we omit allocating at the entry block.
auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
- Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
- builder.getIndexAttr(1));
+ Value one = LLVM::ConstantOp::create(builder, loc, builder.getI64Type(),
+ builder.getIndexAttr(1));
Value allocated =
- builder.create<LLVM::AllocaOp>(loc, ptrType, operand.getType(), one);
+ LLVM::AllocaOp::create(builder, loc, ptrType, operand.getType(), one);
// Store into the alloca'ed descriptor.
- builder.create<LLVM::StoreOp>(loc, operand, allocated);
+ LLVM::StoreOp::create(builder, loc, operand, allocated);
return allocated;
}
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index bf3f31729c3da..e7dd0b506e12d 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -87,17 +87,17 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
auto loc = op->getLoc();
- Value desc = rewriter.create<LLVM::PoisonOp>(loc, resultNDVectoryTy);
+ Value desc = LLVM::PoisonOp::create(rewriter, loc, resultNDVectoryTy);
nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
// For this unrolled `position` corresponding to the `linearIndex`^th
// element, extract operand vectors
SmallVector<Value, 4> extractedOperands;
for (const auto &operand : llvm::enumerate(operands)) {
- extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
- loc, operand.value(), position));
+ extractedOperands.push_back(LLVM::ExtractValueOp::create(
+ rewriter, loc, operand.value(), position));
}
Value newVal = createOperand(result1DVectorTy, extractedOperands);
- desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, newVal, position);
+ desc = LLVM::InsertValueOp::create(rewriter, loc, desc, newVal, position);
});
rewriter.replaceOp(op, desc);
return success();
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index c3f213147b7a7..3f4b4d6cbc8ab 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -78,8 +78,8 @@ getLibraryCallSymbolRef(Operation *op, PatternRewriter &rewriter) {
// Insert before module terminator.
rewriter.setInsertionPoint(module.getBody(),
std::prev(module.getBody()->end()));
- func::FuncOp funcOp = rewriter.create<func::FuncOp>(
- op->getLoc(), fnNameAttr.getValue(), libFnType);
+ func::FuncOp funcOp = func::FuncOp::create(rewriter, op->getLoc(),
+ fnNameAttr.getValue(), libFnType);
// Insert a function attribute that will trigger the emission of the
// corresponding `_mlir_ciface_xxx` interface so that external libraries see
// a normalized ABI. This interface is added during std to llvm conversion.
@@ -100,8 +100,8 @@ createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
res.push_back(op);
continue;
}
- Value cast =
- b.create<memref::CastOp>(loc, makeStridedLayoutDynamic(memrefType), op);
+ Value cast = memref::CastOp::create(
+ b, loc, makeStridedLayoutDynamic(memrefType), op);
res.push_back(cast);
}
return res;
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index d4deff5b88070..5b68eb8188996 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -54,18 +54,18 @@ std::pair<Value, Value> getRawPtrAndSize(const Location loc,
Value memRef, Type elType) {
Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
Value dataPtr =
- rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
- Value offset = rewriter.create<LLVM::ExtractValueOp>(
- loc, rewriter.getI64Type(), memRef, 2);
+ LLVM::ExtractValueOp::create(rewriter, loc, ptrType, memRef, 1);
+ Value offset = LLVM::ExtractValueOp::create(rewriter, loc,
+ rewriter.getI64Type(), memRef, 2);
Value resPtr =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
+ LLVM::GEPOp::create(rewriter, loc, ptrType, elType, dataPtr, offset);
Value size;
if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) {
- size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
- ArrayRef<int64_t>{3, 0});
- size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size);
+ size = LLVM::ExtractValueOp::create(rewriter, loc, memRef,
+ ArrayRef<int64_t>{3, 0});
+ size = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), size);
} else {
- size = rewriter.create<arith::ConstantIntOp>(loc, 1, 32);
+ size = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
}
return {resPtr, size};
}
@@ -157,13 +157,13 @@ class MPICHImplTraits : public MPIImplTraits {
Value getCommWorld(const Location loc,
ConversionPatternRewriter &rewriter) override {
static constexpr int MPI_COMM_WORLD = 0x44000000;
- return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
- MPI_COMM_WORLD);
+ return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
+ MPI_COMM_WORLD);
}
Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
Value comm) override {
- return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), comm);
+ return LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), comm);
}
intptr_t getStatusIgnore() override { return 1; }
@@ -195,7 +195,8 @@ class MPICHImplTraits : public MPIImplTraits {
mtype = MPI_UINT8_T;
else
assert(false && "unsupported type");
- return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), mtype);
+ return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
+ mtype);
}
Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
@@ -245,7 +246,7 @@ class MPICHImplTraits : public MPIImplTraits {
op = MPI_REPLACE;
break;
}
- return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), op);
+ return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), op);
}
};
@@ -281,16 +282,16 @@ class OMPIImplTraits : public MPIImplTraits {
getOrDefineExternalStruct(loc, rewriter, name, commStructT);
// get address of symbol
- auto comm = rewriter.create<LLVM::AddressOfOp>(
- loc, LLVM::LLVMPointerType::get(context),
- SymbolRefAttr::get(context, name));
- return rewriter.create<LLVM::PtrToIntOp>(loc, rewriter.getI64Type(), comm);
+ auto comm = LLVM::AddressOfOp::create(rewriter, loc,
+ LLVM::LLVMPointerType::get(context),
+ SymbolRefAttr::get(context, name));
+ return LLVM::PtrToIntOp::create(rewriter, loc, rewriter.getI64Type(), comm);
}
Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
Value comm) override {
- return rewriter.create<LLVM::IntToPtrOp>(
- loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm);
+ return LLVM::IntToPtrOp::create(
+ rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm);
}
intptr_t getStatusIgnore() override { return 0; }
@@ -330,9 +331,9 @@ class OMPIImplTraits : public MPIImplTraits {
// make sure global op definition exists
getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
// get address of symbol
- return rewriter.create<LLVM::AddressOfOp>(
- loc, LLVM::LLVMPointerType::get(context),
- SymbolRefAttr::get(context, mtype));
+ return LLVM::AddressOfOp::create(rewriter, loc,
+ LLVM::LLVMPointerType::get(context),
+ SymbolRefAttr::get(context, mtype));
}
Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
@@ -389,9 +390,9 @@ class OMPIImplTraits : public MPIImplTraits {
// make sure global op definition exists
getOrDefineExternalStruct(loc, rewriter, op, opStructT);
// get address of symbol
- return rewriter.create<LLVM::AddressOfOp>(
- loc, LLVM::LLVMPointerType::get(context),
- SymbolRefAttr::get(context, op));
+ return LLVM::AddressOfOp::create(rewriter, loc,
+ LLVM::LLVMPointerType::get(context),
+ SymbolRefAttr::get(context, op));
}
};
@@ -424,7 +425,7 @@ struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> {
Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
// instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
- auto nullPtrOp = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
+ auto nullPtrOp = LLVM::ZeroOp::create(rewriter, loc, ptrType);
Value llvmnull = nullPtrOp.getRes();
// grab a reference to the global module op:
@@ -513,9 +514,9 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
// get communicator
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
- auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
+ auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
auto outPtr =
- rewriter.create<LLVM::AllocaOp>(loc, ptrType, comm.getType(), one);
+ LLVM::AllocaOp::create(rewriter, loc, ptrType, comm.getType(), one);
// int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm * newcomm)
auto funcType =
@@ -524,14 +525,14 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter,
"MPI_Comm_split", funcType);
- auto callOp = rewriter.create<LLVM::CallOp>(
- loc, funcDecl,
- ValueRange{comm, adaptor.getColor(), adaptor.getKey(),
- outPtr.getRes()});
+ auto callOp =
+ LLVM::CallOp::create(rewriter, loc, funcDecl,
+ ValueRange{comm, adaptor.getColor(),
+ adaptor.getKey(), outPtr.getRes()});
// load the communicator into a register
- Value res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult());
- res = rewriter.create<LLVM::SExtOp>(loc, rewriter.getI64Type(), res);
+ Value res = LLVM::LoadOp::create(rewriter, loc, i32, outPtr.getResult());
+ res = LLVM::SExtOp::create(rewriter, loc, rewriter.getI64Type(), res);
// if retval is checked, replace uses of retval with the results from the
// call op
@@ -580,14 +581,14 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
// replace with function call
- auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
- auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
- auto callOp = rewriter.create<LLVM::CallOp>(
- loc, initDecl, ValueRange{comm, rankptr.getRes()});
+ auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
+ auto rankptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one);
+ auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl,
+ ValueRange{comm, rankptr.getRes()});
// load the rank into a register
auto loadedRank =
- rewriter.create<LLVM::LoadOp>(loc, i32, rankptr.getResult());
+ LLVM::LoadOp::create(rewriter, loc, i32, rankptr.getResult());
// if retval is checked, replace uses of retval with the results from the
// call op
@@ -641,10 +642,10 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType);
// replace op with function call
- auto funcCall = rewriter.create<LLVM::CallOp>(
- loc, funcDecl,
- ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(),
- comm});
+ auto funcCall = LLVM::CallOp::create(rewriter, loc, funcDecl,
+ ValueRange{dataPtr, size, dataType,
+ adaptor.getDest(),
+ adaptor.getTag(), comm});
if (op.getRetval())
rewriter.replaceOp(op, funcCall.getResult());
else
@@ -683,10 +684,10 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
- Value statusIgnore = rewriter.create<LLVM::ConstantOp>(
- loc, i64, mpiTraits->getStatusIgnore());
+ Value statusIgnore = LLVM::ConstantOp::create(rewriter, loc, i64,
+ mpiTraits->getStatusIgnore());
statusIgnore =
- rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore);
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrType, statusIgnore);
// LLVM Function type representing `i32 MPI_Recv(data, count, datatype, dst,
// tag, comm)`
@@ -698,8 +699,8 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType);
// replace op with function call
- auto funcCall = rewriter.create<LLVM::CallOp>(
- loc, funcDecl,
+ auto funcCall = LLVM::CallOp::create(
+ rewriter, loc, funcDecl,
ValueRange{dataPtr, size, dataType, adaptor.getSource(),
adaptor.getTag(), comm, statusIgnore});
if (op.getRetval())
@@ -738,9 +739,10 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
// If input and output are the same, request in-place operation.
if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
- sendPtr = rewriter.create<LLVM::ConstantOp>(
- loc, i64, reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
- sendPtr = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, sendPtr);
+ sendPtr = LLVM::ConstantOp::create(
+ rewriter, loc, i64,
+ reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
+ sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr);
}
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
@@ -757,8 +759,8 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce", funcType);
// replace op with function call
- auto funcCall = rewriter.create<LLVM::CallOp>(
- loc, funcDecl,
+ auto funcCall = LLVM::CallOp::create(
+ rewriter, loc, funcDecl,
ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
if (op.getRetval())
diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
index 7f4655e53609e..08a456691880c 100644
--- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
+++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
@@ -121,19 +121,19 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
initValueAttr = FloatAttr::get(resultElementType, 0.0);
else
initValueAttr = IntegerAttr::get(resultElementType, 0);
- Value result = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(vecType, initValueAttr));
+ Value result = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(vecType, initValueAttr));
SmallVector<int64_t> strides = computeStrides(shape);
for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
SmallVector<int64_t> positions = delinearize(linearIndex, strides);
SmallVector<Value> operands;
for (Value input : op->getOperands())
operands.push_back(
- rewriter.create<vector::ExtractOp>(loc, input, positions));
+ vector::ExtractOp::create(rewriter, loc, input, positions));
Value scalarOp =
- rewriter.create<Op>(loc, vecType.getElementType(), operands);
+ Op::create(rewriter, loc, vecType.getElementType(), operands);
result =
- rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
+ vector::InsertOp::create(rewriter, loc, scalarOp, result, positions);
}
rewriter.replaceOp(op, result);
return success();
@@ -195,7 +195,7 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
FunctionType funcType = FunctionType::get(
builder.getContext(), {elementType, elementType}, elementType);
- auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
+ auto funcOp = func::FuncOp::create(builder, funcName, funcType);
LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
Attribute linkage =
LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
@@ -208,12 +208,12 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
Value bArg = funcOp.getArgument(0);
Value pArg = funcOp.getArgument(1);
builder.setInsertionPointToEnd(entryBlock);
- Value zeroValue = builder.create<arith::ConstantOp>(
- elementType, builder.getIntegerAttr(elementType, 0));
- Value oneValue = builder.create<arith::ConstantOp>(
- elementType, builder.getIntegerAttr(elementType, 1));
- Value minusOneValue = builder.create<arith::ConstantOp>(
- elementType,
+ Value zeroValue = arith::ConstantOp::create(
+ builder, elementType, builder.getIntegerAttr(elementType, 0));
+ Value oneValue = arith::ConstantOp::create(
+ builder, elementType, builder.getIntegerAttr(elementType, 1));
+ Value minusOneValue = arith::ConstantOp::create(
+ builder, elementType,
builder.getIntegerAttr(elementType,
APInt(elementType.getIntOrFloatBitWidth(), -1ULL,
/*isSigned=*/true)));
@@ -221,82 +221,83 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
// if (p == T(0))
// return T(1);
auto pIsZero =
- builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroValue);
+ arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg, zeroValue);
Block *thenBlock = builder.createBlock(funcBody);
- builder.create<func::ReturnOp>(oneValue);
+ func::ReturnOp::create(builder, oneValue);
Block *fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (p == T(0)).
builder.setInsertionPointToEnd(pIsZero->getBlock());
- builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
+ cf::CondBranchOp::create(builder, pIsZero, thenBlock, fallthroughBlock);
// if (p < T(0)) {
builder.setInsertionPointToEnd(fallthroughBlock);
- auto pIsNeg =
- builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg, zeroValue);
+ auto pIsNeg = arith::CmpIOp::create(builder, arith::CmpIPredicate::sle, pArg,
+ zeroValue);
// if (b == T(0))
builder.createBlock(funcBody);
auto bIsZero =
- builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, zeroValue);
+ arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, bArg, zeroValue);
// return T(1) / T(0);
thenBlock = builder.createBlock(funcBody);
- builder.create<func::ReturnOp>(
- builder.create<arith::DivSIOp>(oneValue, zeroValue).getResult());
+ func::ReturnOp::create(
+ builder,
+ arith::DivSIOp::create(builder, oneValue, zeroValue).getResult());
fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (b == T(0)).
builder.setInsertionPointToEnd(bIsZero->getBlock());
- builder.create<cf::CondBranchOp>(bIsZero, thenBlock, fallthroughBlock);
+ cf::CondBranchOp::create(builder, bIsZero, thenBlock, fallthroughBlock);
// if (b == T(1))
builder.setInsertionPointToEnd(fallthroughBlock);
auto bIsOne =
- builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, oneValue);
+ arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, bArg, oneValue);
// return T(1);
thenBlock = builder.createBlock(funcBody);
- builder.create<func::ReturnOp>(oneValue);
+ func::ReturnOp::create(builder, oneValue);
fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (b == T(1)).
builder.setInsertionPointToEnd(bIsOne->getBlock());
- builder.create<cf::CondBranchOp>(bIsOne, thenBlock, fallthroughBlock);
+ cf::CondBranchOp::create(builder, bIsOne, thenBlock, fallthroughBlock);
// if (b == T(-1)) {
builder.setInsertionPointToEnd(fallthroughBlock);
- auto bIsMinusOne = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
- bArg, minusOneValue);
+ auto bIsMinusOne = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq,
+ bArg, minusOneValue);
// if (p & T(1))
builder.createBlock(funcBody);
- auto pIsOdd = builder.create<arith::CmpIOp>(
- arith::CmpIPredicate::ne, builder.create<arith::AndIOp>(pArg, oneValue),
- zeroValue);
+ auto pIsOdd = arith::CmpIOp::create(
+ builder, arith::CmpIPredicate::ne,
+ arith::AndIOp::create(builder, pArg, oneValue), zeroValue);
// return T(-1);
thenBlock = builder.createBlock(funcBody);
- builder.create<func::ReturnOp>(minusOneValue);
+ func::ReturnOp::create(builder, minusOneValue);
fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (p & T(1)).
builder.setInsertionPointToEnd(pIsOdd->getBlock());
- builder.create<cf::CondBranchOp>(pIsOdd, thenBlock, fallthroughBlock);
+ cf::CondBranchOp::create(builder, pIsOdd, thenBlock, fallthroughBlock);
// return T(1);
// } // b == T(-1)
builder.setInsertionPointToEnd(fallthroughBlock);
- builder.create<func::ReturnOp>(oneValue);
+ func::ReturnOp::create(builder, oneValue);
fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (b == T(-1)).
builder.setInsertionPointToEnd(bIsMinusOne->getBlock());
- builder.create<cf::CondBranchOp>(bIsMinusOne, pIsOdd->getBlock(),
- fallthroughBlock);
+ cf::CondBranchOp::create(builder, bIsMinusOne, pIsOdd->getBlock(),
+ fallthroughBlock);
// return T(0);
// } // (p < T(0))
builder.setInsertionPointToEnd(fallthroughBlock);
- builder.create<func::ReturnOp>(zeroValue);
+ func::ReturnOp::create(builder, zeroValue);
Block *loopHeader = builder.createBlock(
funcBody, funcBody->end(), {elementType, elementType, elementType},
{builder.getLoc(), builder.getLoc(), builder.getLoc()});
// Set up conditional branch for (p < T(0)).
builder.setInsertionPointToEnd(pIsNeg->getBlock());
// Set initial values of 'result', 'b' and 'p' for the loop.
- builder.create<cf::CondBranchOp>(pIsNeg, bIsZero->getBlock(), loopHeader,
- ValueRange{oneValue, bArg, pArg});
+ cf::CondBranchOp::create(builder, pIsNeg, bIsZero->getBlock(), loopHeader,
+ ValueRange{oneValue, bArg, pArg});
// T result = T(1);
// while (true) {
@@ -313,45 +314,46 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
builder.setInsertionPointToEnd(loopHeader);
// if (p & T(1))
- auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
- arith::CmpIPredicate::ne,
- builder.create<arith::AndIOp>(powerTmp, oneValue), zeroValue);
+ auto powerTmpIsOdd = arith::CmpIOp::create(
+ builder, arith::CmpIPredicate::ne,
+ arith::AndIOp::create(builder, powerTmp, oneValue), zeroValue);
thenBlock = builder.createBlock(funcBody);
// result *= b;
- Value newResultTmp = builder.create<arith::MulIOp>(resultTmp, baseTmp);
+ Value newResultTmp = arith::MulIOp::create(builder, resultTmp, baseTmp);
fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType,
builder.getLoc());
builder.setInsertionPointToEnd(thenBlock);
- builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
+ cf::BranchOp::create(builder, newResultTmp, fallthroughBlock);
// Set up conditional branch for (p & T(1)).
builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
- builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
- resultTmp);
+ cf::CondBranchOp::create(builder, powerTmpIsOdd, thenBlock, fallthroughBlock,
+ resultTmp);
// Merged 'result'.
newResultTmp = fallthroughBlock->getArgument(0);
// p >>= T(1);
builder.setInsertionPointToEnd(fallthroughBlock);
- Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, oneValue);
+ Value newPowerTmp = arith::ShRUIOp::create(builder, powerTmp, oneValue);
// if (p == T(0))
- auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
- newPowerTmp, zeroValue);
+ auto newPowerIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq,
+ newPowerTmp, zeroValue);
// return result;
thenBlock = builder.createBlock(funcBody);
- builder.create<func::ReturnOp>(newResultTmp);
+ func::ReturnOp::create(builder, newResultTmp);
fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (p == T(0)).
builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
- builder.create<cf::CondBranchOp>(newPowerIsZero, thenBlock, fallthroughBlock);
+ cf::CondBranchOp::create(builder, newPowerIsZero, thenBlock,
+ fallthroughBlock);
// b *= b;
// }
builder.setInsertionPointToEnd(fallthroughBlock);
- Value newBaseTmp = builder.create<arith::MulIOp>(baseTmp, baseTmp);
+ Value newBaseTmp = arith::MulIOp::create(builder, baseTmp, baseTmp);
// Pass new values for 'result', 'b' and 'p' to the loop header.
- builder.create<cf::BranchOp>(
- ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
+ cf::BranchOp::create(
+ builder, ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
return funcOp;
}
@@ -420,7 +422,7 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module,
llvm::raw_string_ostream nameOS(funcName);
nameOS << '_' << baseType;
nameOS << '_' << powType;
- auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
+ auto funcOp = func::FuncOp::create(builder, funcName, funcType);
LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
Attribute linkage =
LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
@@ -433,46 +435,48 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module,
Value bArg = funcOp.getArgument(0);
Value pArg = funcOp.getArgument(1);
builder.setInsertionPointToEnd(entryBlock);
- Value oneBValue = builder.create<arith::ConstantOp>(
- baseType, builder.getFloatAttr(baseType, 1.0));
- Value zeroPValue = builder.create<arith::ConstantOp>(
- powType, builder.getIntegerAttr(powType, 0));
- Value onePValue = builder.create<arith::ConstantOp>(
- powType, builder.getIntegerAttr(powType, 1));
- Value minPValue = builder.create<arith::ConstantOp>(
- powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMinValue(
- powType.getWidth())));
- Value maxPValue = builder.create<arith::ConstantOp>(
- powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMaxValue(
- powType.getWidth())));
+ Value oneBValue = arith::ConstantOp::create(
+ builder, baseType, builder.getFloatAttr(baseType, 1.0));
+ Value zeroPValue = arith::ConstantOp::create(
+ builder, powType, builder.getIntegerAttr(powType, 0));
+ Value onePValue = arith::ConstantOp::create(
+ builder, powType, builder.getIntegerAttr(powType, 1));
+ Value minPValue = arith::ConstantOp::create(
+ builder, powType,
+ builder.getIntegerAttr(
+ powType, llvm::APInt::getSignedMinValue(powType.getWidth())));
+ Value maxPValue = arith::ConstantOp::create(
+ builder, powType,
+ builder.getIntegerAttr(
+ powType, llvm::APInt::getSignedMaxValue(powType.getWidth())));
// if (p == Tp{0})
// return Tb{1};
- auto pIsZero =
- builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroPValue);
+ auto pIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg,
+ zeroPValue);
Block *thenBlock = builder.createBlock(funcBody);
- builder.create<func::ReturnOp>(oneBValue);
+ func::ReturnOp::create(builder, oneBValue);
Block *fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (p == Tp{0}).
builder.setInsertionPointToEnd(pIsZero->getBlock());
- builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
+ cf::CondBranchOp::create(builder, pIsZero, thenBlock, fallthroughBlock);
builder.setInsertionPointToEnd(fallthroughBlock);
// bool isNegativePower{p < Tp{0}}
- auto pIsNeg = builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg,
- zeroPValue);
+ auto pIsNeg = arith::CmpIOp::create(builder, arith::CmpIPredicate::sle, pArg,
+ zeroPValue);
// bool isMin{p == std::numeric_limits<Tp>::min()};
auto pIsMin =
- builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, minPValue);
+ arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg, minPValue);
// if (isMin) {
// p = std::numeric_limits<Tp>::max();
// } else if (isNegativePower) {
// p = -p;
// }
- Value negP = builder.create<arith::SubIOp>(zeroPValue, pArg);
- auto pInit = builder.create<arith::SelectOp>(pIsNeg, negP, pArg);
- pInit = builder.create<arith::SelectOp>(pIsMin, maxPValue, pInit);
+ Value negP = arith::SubIOp::create(builder, zeroPValue, pArg);
+ auto pInit = arith::SelectOp::create(builder, pIsNeg, negP, pArg);
+ pInit = arith::SelectOp::create(builder, pIsMin, maxPValue, pInit);
// Tb result = Tb{1};
// Tb origBase = Tb{b};
@@ -489,7 +493,7 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module,
{builder.getLoc(), builder.getLoc(), builder.getLoc()});
// Set initial values of 'result', 'b' and 'p' for the loop.
builder.setInsertionPointToEnd(pInit->getBlock());
- builder.create<cf::BranchOp>(loopHeader, ValueRange{oneBValue, bArg, pInit});
+ cf::BranchOp::create(builder, loopHeader, ValueRange{oneBValue, bArg, pInit});
// Create loop body.
Value resultTmp = loopHeader->getArgument(0);
@@ -498,30 +502,30 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module,
builder.setInsertionPointToEnd(loopHeader);
// if (p & Tp{1})
- auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
- arith::CmpIPredicate::ne,
- builder.create<arith::AndIOp>(powerTmp, onePValue), zeroPValue);
+ auto powerTmpIsOdd = arith::CmpIOp::create(
+ builder, arith::CmpIPredicate::ne,
+ arith::AndIOp::create(builder, powerTmp, onePValue), zeroPValue);
thenBlock = builder.createBlock(funcBody);
// result *= b;
- Value newResultTmp = builder.create<arith::MulFOp>(resultTmp, baseTmp);
+ Value newResultTmp = arith::MulFOp::create(builder, resultTmp, baseTmp);
fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
builder.getLoc());
builder.setInsertionPointToEnd(thenBlock);
- builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
+ cf::BranchOp::create(builder, newResultTmp, fallthroughBlock);
// Set up conditional branch for (p & Tp{1}).
builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
- builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
- resultTmp);
+ cf::CondBranchOp::create(builder, powerTmpIsOdd, thenBlock, fallthroughBlock,
+ resultTmp);
// Merged 'result'.
newResultTmp = fallthroughBlock->getArgument(0);
// p >>= Tp{1};
builder.setInsertionPointToEnd(fallthroughBlock);
- Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, onePValue);
+ Value newPowerTmp = arith::ShRUIOp::create(builder, powerTmp, onePValue);
// if (p == Tp{0})
- auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
- newPowerTmp, zeroPValue);
+ auto newPowerIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq,
+ newPowerTmp, zeroPValue);
// break;
//
// The conditional branch is finalized below with a jump to
@@ -531,10 +535,10 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module,
// b *= b;
// }
builder.setInsertionPointToEnd(fallthroughBlock);
- Value newBaseTmp = builder.create<arith::MulFOp>(baseTmp, baseTmp);
+ Value newBaseTmp = arith::MulFOp::create(builder, baseTmp, baseTmp);
// Pass new values for 'result', 'b' and 'p' to the loop header.
- builder.create<cf::BranchOp>(
- ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
+ cf::BranchOp::create(
+ builder, ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
// Set up conditional branch for early loop exit:
// if (p == Tp{0})
@@ -542,8 +546,8 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module,
Block *loopExit = builder.createBlock(funcBody, funcBody->end(), baseType,
builder.getLoc());
builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
- builder.create<cf::CondBranchOp>(newPowerIsZero, loopExit, newResultTmp,
- fallthroughBlock, ValueRange{});
+ cf::CondBranchOp::create(builder, newPowerIsZero, loopExit, newResultTmp,
+ fallthroughBlock, ValueRange{});
// if (isMin) {
// result *= origBase;
@@ -553,11 +557,11 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module,
fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
builder.getLoc());
builder.setInsertionPointToEnd(loopExit);
- builder.create<cf::CondBranchOp>(pIsMin, thenBlock, fallthroughBlock,
- newResultTmp);
+ cf::CondBranchOp::create(builder, pIsMin, thenBlock, fallthroughBlock,
+ newResultTmp);
builder.setInsertionPointToEnd(thenBlock);
- newResultTmp = builder.create<arith::MulFOp>(newResultTmp, bArg);
- builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
+ newResultTmp = arith::MulFOp::create(builder, newResultTmp, bArg);
+ cf::BranchOp::create(builder, newResultTmp, fallthroughBlock);
/// if (isNegativePower) {
/// result = Tb{1} / result;
@@ -567,15 +571,15 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module,
Block *returnBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
builder.getLoc());
builder.setInsertionPointToEnd(fallthroughBlock);
- builder.create<cf::CondBranchOp>(pIsNeg, thenBlock, returnBlock,
- newResultTmp);
+ cf::CondBranchOp::create(builder, pIsNeg, thenBlock, returnBlock,
+ newResultTmp);
builder.setInsertionPointToEnd(thenBlock);
- newResultTmp = builder.create<arith::DivFOp>(oneBValue, newResultTmp);
- builder.create<cf::BranchOp>(newResultTmp, returnBlock);
+ newResultTmp = arith::DivFOp::create(builder, oneBValue, newResultTmp);
+ cf::BranchOp::create(builder, newResultTmp, returnBlock);
// return result;
builder.setInsertionPointToEnd(returnBlock);
- builder.create<func::ReturnOp>(returnBlock->getArgument(0));
+ func::ReturnOp::create(builder, returnBlock->getArgument(0));
return funcOp;
}
@@ -667,7 +671,7 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
nameOS << '_' << elementType;
FunctionType funcType =
FunctionType::get(builder.getContext(), {elementType}, elementType);
- auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
+ auto funcOp = func::FuncOp::create(builder, funcName, funcType);
// LinkonceODR ensures that there is only one implementation of this function
// across all math.ctlz functions that are lowered in this way.
@@ -683,33 +687,34 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
Value arg = funcOp.getArgument(0);
Type indexType = builder.getIndexType();
- Value bitWidthValue = builder.create<arith::ConstantOp>(
- elementType, builder.getIntegerAttr(elementType, bitWidth));
- Value zeroValue = builder.create<arith::ConstantOp>(
- elementType, builder.getIntegerAttr(elementType, 0));
+ Value bitWidthValue = arith::ConstantOp::create(
+ builder, elementType, builder.getIntegerAttr(elementType, bitWidth));
+ Value zeroValue = arith::ConstantOp::create(
+ builder, elementType, builder.getIntegerAttr(elementType, 0));
Value inputEqZero =
- builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, arg, zeroValue);
+ arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, arg, zeroValue);
// if input == 0, return bit width, else enter loop.
- scf::IfOp ifOp = builder.create<scf::IfOp>(
- elementType, inputEqZero, /*addThenBlock=*/true, /*addElseBlock=*/true);
+ scf::IfOp ifOp =
+ scf::IfOp::create(builder, elementType, inputEqZero,
+ /*addThenBlock=*/true, /*addElseBlock=*/true);
ifOp.getThenBodyBuilder().create<scf::YieldOp>(loc, bitWidthValue);
auto elseBuilder =
ImplicitLocOpBuilder::atBlockEnd(loc, &ifOp.getElseRegion().front());
- Value oneIndex = elseBuilder.create<arith::ConstantOp>(
- indexType, elseBuilder.getIndexAttr(1));
- Value oneValue = elseBuilder.create<arith::ConstantOp>(
- elementType, elseBuilder.getIntegerAttr(elementType, 1));
- Value bitWidthIndex = elseBuilder.create<arith::ConstantOp>(
- indexType, elseBuilder.getIndexAttr(bitWidth));
- Value nValue = elseBuilder.create<arith::ConstantOp>(
- elementType, elseBuilder.getIntegerAttr(elementType, 0));
-
- auto loop = elseBuilder.create<scf::ForOp>(
- oneIndex, bitWidthIndex, oneIndex,
+ Value oneIndex = arith::ConstantOp::create(elseBuilder, indexType,
+ elseBuilder.getIndexAttr(1));
+ Value oneValue = arith::ConstantOp::create(
+ elseBuilder, elementType, elseBuilder.getIntegerAttr(elementType, 1));
+ Value bitWidthIndex = arith::ConstantOp::create(
+ elseBuilder, indexType, elseBuilder.getIndexAttr(bitWidth));
+ Value nValue = arith::ConstantOp::create(
+ elseBuilder, elementType, elseBuilder.getIntegerAttr(elementType, 0));
+
+ auto loop = scf::ForOp::create(
+ elseBuilder, oneIndex, bitWidthIndex, oneIndex,
// Initial values for two loop induction variables, the arg which is being
// shifted left in each iteration, and the n value which tracks the count
// of leading zeros.
@@ -725,25 +730,25 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
Value argIter = args[0];
Value nIter = args[1];
- Value argIsNonNegative = b.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, argIter, zeroValue);
- scf::IfOp ifOp = b.create<scf::IfOp>(
- loc, argIsNonNegative,
+ Value argIsNonNegative = arith::CmpIOp::create(
+ b, loc, arith::CmpIPredicate::slt, argIter, zeroValue);
+ scf::IfOp ifOp = scf::IfOp::create(
+ b, loc, argIsNonNegative,
[&](OpBuilder &b, Location loc) {
// If arg is negative, continue (effectively, break)
- b.create<scf::YieldOp>(loc, ValueRange{argIter, nIter});
+ scf::YieldOp::create(b, loc, ValueRange{argIter, nIter});
},
[&](OpBuilder &b, Location loc) {
// Otherwise, increment n and shift arg left.
- Value nNext = b.create<arith::AddIOp>(loc, nIter, oneValue);
- Value argNext = b.create<arith::ShLIOp>(loc, argIter, oneValue);
- b.create<scf::YieldOp>(loc, ValueRange{argNext, nNext});
+ Value nNext = arith::AddIOp::create(b, loc, nIter, oneValue);
+ Value argNext = arith::ShLIOp::create(b, loc, argIter, oneValue);
+ scf::YieldOp::create(b, loc, ValueRange{argNext, nNext});
});
- b.create<scf::YieldOp>(loc, ifOp.getResults());
+ scf::YieldOp::create(b, loc, ifOp.getResults());
});
- elseBuilder.create<scf::YieldOp>(loop.getResult(1));
+ scf::YieldOp::create(elseBuilder, loop.getResult(1));
- builder.create<func::ReturnOp>(ifOp.getResult(0));
+ func::ReturnOp::create(builder, ifOp.getResult(0));
return funcOp;
}
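[Editorial aside for readers of the archived patch; this note and the sketch below are not part of the diff. Every hunk in this series applies the same mechanical rewrite: the `rewriter.create<OpTy>(...)` / `builder.create<OpTy>(...)` form is replaced by the op class's static `OpTy::create(builder, ...)` entry point, with the builder passed as the first argument and all remaining arguments unchanged. A minimal sketch of that pattern, assuming a hypothetical context with a ConversionPatternRewriter `rewriter`, a Location `loc`, and two Values `lhs` and `rhs` (none of which come from the hunks above), looks like this:]

  // Sketch only -- restates the rewrite the hunks apply, not code from the diff.
  // Old form, removed throughout this series:
  //   Value sum = rewriter.create<LLVM::AddOp>(loc, lhs, rhs);
  // New builder-first form, added throughout this series:
  Value sum = LLVM::AddOp::create(rewriter, loc, lhs, rhs);

[End of aside; the diff continues below unchanged.]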
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index f4d69ce8235bb..853f45498ac52 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -107,8 +107,8 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), typeConverter,
[&](Type llvm1DVectorTy, ValueRange operands) {
- return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
- false);
+ return LLVMOp::create(rewriter, loc, llvm1DVectorTy, operands[0],
+ false);
},
rewriter);
}
@@ -145,15 +145,16 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
LLVM::ConstantOp one;
if (LLVM::isCompatibleVectorType(llvmOperandType)) {
- one = rewriter.create<LLVM::ConstantOp>(
- loc, llvmOperandType,
+ one = LLVM::ConstantOp::create(
+ rewriter, loc, llvmOperandType,
SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
floatOne));
} else {
- one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne);
+ one =
+ LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
}
- auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(),
- expAttrs.getAttrs());
+ auto exp = LLVM::ExpOp::create(rewriter, loc, adaptor.getOperand(),
+ expAttrs.getAttrs());
rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
op, llvmOperandType, ValueRange{exp, one}, subAttrs.getAttrs());
return success();
@@ -170,12 +171,13 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
{numElements.isScalable()}),
floatOne);
- auto one =
- rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
- auto exp = rewriter.create<LLVM::ExpOp>(
- loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs());
- return rewriter.create<LLVM::FSubOp>(
- loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs());
+ auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
+ splatAttr);
+ auto exp = LLVM::ExpOp::create(rewriter, loc, llvm1DVectorTy,
+ operands[0], expAttrs.getAttrs());
+ return LLVM::FSubOp::create(rewriter, loc, llvm1DVectorTy,
+ ValueRange{exp, one},
+ subAttrs.getAttrs());
},
rewriter);
}
@@ -205,16 +207,16 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
LLVM::ConstantOp one =
isa<VectorType>(llvmOperandType)
- ? rewriter.create<LLVM::ConstantOp>(
- loc, llvmOperandType,
+ ? LLVM::ConstantOp::create(
+ rewriter, loc, llvmOperandType,
SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
floatOne))
- : rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType,
- floatOne);
+ : LLVM::ConstantOp::create(rewriter, loc, llvmOperandType,
+ floatOne);
- auto add = rewriter.create<LLVM::FAddOp>(
- loc, llvmOperandType, ValueRange{one, adaptor.getOperand()},
- addAttrs.getAttrs());
+ auto add = LLVM::FAddOp::create(rewriter, loc, llvmOperandType,
+ ValueRange{one, adaptor.getOperand()},
+ addAttrs.getAttrs());
rewriter.replaceOpWithNewOp<LLVM::LogOp>(
op, llvmOperandType, ValueRange{add}, logAttrs.getAttrs());
return success();
@@ -231,13 +233,13 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
{numElements.isScalable()}),
floatOne);
- auto one =
- rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
- auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy,
- ValueRange{one, operands[0]},
- addAttrs.getAttrs());
- return rewriter.create<LLVM::LogOp>(
- loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs());
+ auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
+ splatAttr);
+ auto add = LLVM::FAddOp::create(rewriter, loc, llvm1DVectorTy,
+ ValueRange{one, operands[0]},
+ addAttrs.getAttrs());
+ return LLVM::LogOp::create(rewriter, loc, llvm1DVectorTy,
+ ValueRange{add}, logAttrs.getAttrs());
},
rewriter);
}
@@ -267,15 +269,16 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
LLVM::ConstantOp one;
if (isa<VectorType>(llvmOperandType)) {
- one = rewriter.create<LLVM::ConstantOp>(
- loc, llvmOperandType,
+ one = LLVM::ConstantOp::create(
+ rewriter, loc, llvmOperandType,
SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
floatOne));
} else {
- one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne);
+ one =
+ LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
}
- auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
- sqrtAttrs.getAttrs());
+ auto sqrt = LLVM::SqrtOp::create(rewriter, loc, adaptor.getOperand(),
+ sqrtAttrs.getAttrs());
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
op, llvmOperandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
return success();
@@ -292,12 +295,13 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
{numElements.isScalable()}),
floatOne);
- auto one =
- rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
- auto sqrt = rewriter.create<LLVM::SqrtOp>(
- loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs());
- return rewriter.create<LLVM::FDivOp>(
- loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs());
+ auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
+ splatAttr);
+ auto sqrt = LLVM::SqrtOp::create(rewriter, loc, llvm1DVectorTy,
+ operands[0], sqrtAttrs.getAttrs());
+ return LLVM::FDivOp::create(rewriter, loc, llvm1DVectorTy,
+ ValueRange{one, sqrt},
+ divAttrs.getAttrs());
},
rewriter);
}
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index a0ce7d3b75fc2..f7c0d4fe3a799 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -84,20 +84,21 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
auto shape = vecType.getShape();
int64_t numElements = vecType.getNumElements();
- Value result = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(
- vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
+ Value result = arith::ConstantOp::create(
+ rewriter, loc,
+ DenseElementsAttr::get(vecType,
+ FloatAttr::get(vecType.getElementType(), 0.0)));
SmallVector<int64_t> strides = computeStrides(shape);
for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
SmallVector<int64_t> positions = delinearize(linearIndex, strides);
SmallVector<Value> operands;
for (auto input : op->getOperands())
operands.push_back(
- rewriter.create<vector::ExtractOp>(loc, input, positions));
+ vector::ExtractOp::create(rewriter, loc, input, positions));
Value scalarOp =
- rewriter.create<Op>(loc, vecType.getElementType(), operands);
+ Op::create(rewriter, loc, vecType.getElementType(), operands);
result =
- rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
+ vector::InsertOp::create(rewriter, loc, scalarOp, result, positions);
}
rewriter.replaceOp(op, {result});
return success();
@@ -114,9 +115,9 @@ PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
auto f32 = rewriter.getF32Type();
auto extendedOperands = llvm::to_vector(
llvm::map_range(op->getOperands(), [&](Value operand) -> Value {
- return rewriter.create<arith::ExtFOp>(loc, f32, operand);
+ return arith::ExtFOp::create(rewriter, loc, f32, operand);
}));
- auto newOp = rewriter.create<Op>(loc, f32, extendedOperands);
+ auto newOp = Op::create(rewriter, loc, f32, extendedOperands);
rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, opType, newOp);
return success();
}
@@ -139,8 +140,8 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
rewriter.setInsertionPointToStart(&module->getRegion(0).front());
auto opFunctionTy = FunctionType::get(
rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
- opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name,
- opFunctionTy);
+ opFunc = func::FuncOp::create(rewriter, rewriter.getUnknownLoc(), name,
+ opFunctionTy);
opFunc.setPrivate();
// By definition Math dialect operations imply LLVM's "readnone"
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 59db14ed816be..a877ad21734a2 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -36,12 +36,12 @@ static Value getScalarOrVectorI32Constant(Type type, int value,
if (!vectorType.getElementType().isInteger(32))
return nullptr;
SmallVector<int> values(vectorType.getNumElements(), value);
- return builder.create<spirv::ConstantOp>(loc, type,
- builder.getI32VectorAttr(values));
+ return spirv::ConstantOp::create(builder, loc, type,
+ builder.getI32VectorAttr(values));
}
if (type.isInteger(32))
- return builder.create<spirv::ConstantOp>(loc, type,
- builder.getI32IntegerAttr(value));
+ return spirv::ConstantOp::create(builder, loc, type,
+ builder.getI32IntegerAttr(value));
return nullptr;
}
@@ -144,10 +144,11 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
Type intType = rewriter.getIntegerType(bitwidth);
uint64_t intValue = uint64_t(1) << (bitwidth - 1);
- Value signMask = rewriter.create<spirv::ConstantOp>(
- loc, intType, rewriter.getIntegerAttr(intType, intValue));
- Value valueMask = rewriter.create<spirv::ConstantOp>(
- loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u));
+ Value signMask = spirv::ConstantOp::create(
+ rewriter, loc, intType, rewriter.getIntegerAttr(intType, intValue));
+ Value valueMask = spirv::ConstantOp::create(
+ rewriter, loc, intType,
+ rewriter.getIntegerAttr(intType, intValue - 1u));
if (auto vectorType = dyn_cast<VectorType>(type)) {
assert(vectorType.getRank() == 1);
@@ -155,26 +156,26 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
intType = VectorType::get(count, intType);
SmallVector<Value> signSplat(count, signMask);
- signMask =
- rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat);
+ signMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
+ signSplat);
SmallVector<Value> valueSplat(count, valueMask);
- valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
- valueSplat);
+ valueMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
+ valueSplat);
}
Value lhsCast =
- rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
+ spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getLhs());
Value rhsCast =
- rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
+ spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getRhs());
- Value value = rewriter.create<spirv::BitwiseAndOp>(
- loc, intType, ValueRange{lhsCast, valueMask});
- Value sign = rewriter.create<spirv::BitwiseAndOp>(
- loc, intType, ValueRange{rhsCast, signMask});
+ Value value = spirv::BitwiseAndOp::create(rewriter, loc, intType,
+ ValueRange{lhsCast, valueMask});
+ Value sign = spirv::BitwiseAndOp::create(rewriter, loc, intType,
+ ValueRange{rhsCast, signMask});
- Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType,
- ValueRange{value, sign});
+ Value result = spirv::BitwiseOrOp::create(rewriter, loc, intType,
+ ValueRange{value, sign});
rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
return success();
}
@@ -214,18 +215,18 @@ struct CountLeadingZerosPattern final
Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
- Value msb = rewriter.create<spirv::GLFindUMsbOp>(loc, input);
+ Value msb = spirv::GLFindUMsbOp::create(rewriter, loc, input);
// We need to subtract from 31 given that the index returned by GLSL
// FindUMsb is counted from the least significant bit. Theoretically this
// also gives the correct result even if the integer has all zero bits, in
// which case GL FindUMsb would return -1.
- Value subMsb = rewriter.create<spirv::ISubOp>(loc, val31, msb);
+ Value subMsb = spirv::ISubOp::create(rewriter, loc, val31, msb);
// However, certain Vulkan implementations have driver bugs for the corner
// case where the input is zero. And.. it can be smart to optimize a select
// only involving the corner case. So separately compute the result when the
// input is either zero or one.
- Value subInput = rewriter.create<spirv::ISubOp>(loc, val32, input);
- Value cmp = rewriter.create<spirv::ULessThanEqualOp>(loc, input, val1);
+ Value subInput = spirv::ISubOp::create(rewriter, loc, val32, input);
+ Value cmp = spirv::ULessThanEqualOp::create(rewriter, loc, input, val1);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput,
subMsb);
return success();
@@ -253,7 +254,7 @@ struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
if (!type)
return failure();
- Value exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
+ Value exp = ExpOp::create(rewriter, loc, type, adaptor.getOperand());
auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
return success();
@@ -283,7 +284,7 @@ struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
Value onePlus =
- rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand());
+ spirv::FAddOp::create(rewriter, loc, one, adaptor.getOperand());
rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
return success();
}
@@ -321,15 +322,15 @@ struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> {
auto getConstantValue = [&](double value) {
if (auto floatType = dyn_cast<FloatType>(type)) {
- return rewriter.create<spirv::ConstantOp>(
- loc, type, rewriter.getFloatAttr(floatType, value));
+ return spirv::ConstantOp::create(
+ rewriter, loc, type, rewriter.getFloatAttr(floatType, value));
}
if (auto vectorType = dyn_cast<VectorType>(type)) {
Type elemType = vectorType.getElementType();
if (isa<FloatType>(elemType)) {
- return rewriter.create<spirv::ConstantOp>(
- loc, type,
+ return spirv::ConstantOp::create(
+ rewriter, loc, type,
DenseFPElementsAttr::get(
vectorType, FloatAttr::get(elemType, value).getValue()));
}
@@ -341,7 +342,7 @@ struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> {
Value constantValue = getConstantValue(
std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal
: log10Reciprocal);
- Value log = rewriter.create<SpirvLogOp>(loc, adaptor.getOperand());
+ Value log = SpirvLogOp::create(rewriter, loc, adaptor.getOperand());
rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log,
constantValue);
return success();
@@ -386,7 +387,7 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
Location loc = powfOp.getLoc();
Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter);
Value lessThan =
- rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
+ spirv::FOrdLessThanOp::create(rewriter, loc, adaptor.getLhs(), zero);
// Per C/C++ spec:
// > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is
@@ -394,11 +395,11 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
// Calculate the reminder from the exponent and check whether it is zero.
Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter);
Value expRem =
- rewriter.create<spirv::FRemOp>(loc, adaptor.getRhs(), floatOne);
+ spirv::FRemOp::create(rewriter, loc, adaptor.getRhs(), floatOne);
Value expRemNonZero =
- rewriter.create<spirv::FOrdNotEqualOp>(loc, expRem, zero);
+ spirv::FOrdNotEqualOp::create(rewriter, loc, expRem, zero);
Value cmpNegativeWithFractionalExp =
- rewriter.create<spirv::LogicalAndOp>(loc, expRemNonZero, lessThan);
+ spirv::LogicalAndOp::create(rewriter, loc, expRemNonZero, lessThan);
// Create NaN result and replace base value if conditions are met.
const auto &floatSemantics = scalarFloatType.getFloatSemantics();
const auto nan = APFloat::getNaN(floatSemantics);
@@ -407,10 +408,11 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
nanAttr = DenseElementsAttr::get(vectorType, nan);
Value NanValue =
- rewriter.create<spirv::ConstantOp>(loc, operandType, nanAttr);
- Value lhs = rewriter.create<spirv::SelectOp>(
- loc, cmpNegativeWithFractionalExp, NanValue, adaptor.getLhs());
- Value abs = rewriter.create<spirv::GLFAbsOp>(loc, lhs);
+ spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr);
+ Value lhs =
+ spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp,
+ NanValue, adaptor.getLhs());
+ Value abs = spirv::GLFAbsOp::create(rewriter, loc, lhs);
// TODO: The following just forcefully casts y into an integer value in
// order to properly propagate the sign, assuming integer y cases. It
@@ -418,18 +420,18 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
// Cast exponent to integer and calculate exponent % 2 != 0.
Value intRhs =
- rewriter.create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs());
+ spirv::ConvertFToSOp::create(rewriter, loc, intType, adaptor.getRhs());
Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
Value bitwiseAndOne =
- rewriter.create<spirv::BitwiseAndOp>(loc, intRhs, intOne);
- Value isOdd = rewriter.create<spirv::IEqualOp>(loc, bitwiseAndOne, intOne);
+ spirv::BitwiseAndOp::create(rewriter, loc, intRhs, intOne);
+ Value isOdd = spirv::IEqualOp::create(rewriter, loc, bitwiseAndOne, intOne);
// calculate pow based on abs(lhs)^rhs.
- Value pow = rewriter.create<spirv::GLPowOp>(loc, abs, adaptor.getRhs());
- Value negate = rewriter.create<spirv::FNegateOp>(loc, pow);
+ Value pow = spirv::GLPowOp::create(rewriter, loc, abs, adaptor.getRhs());
+ Value negate = spirv::FNegateOp::create(rewriter, loc, pow);
// if the exponent is odd and lhs < 0, negate the result.
Value shouldNegate =
- rewriter.create<spirv::LogicalAndOp>(loc, lessThan, isOdd);
+ spirv::LogicalAndOp::create(rewriter, loc, lessThan, isOdd);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
pow);
return success();
@@ -455,22 +457,22 @@ struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
Value half;
if (VectorType vty = dyn_cast<VectorType>(ty)) {
- half = rewriter.create<spirv::ConstantOp>(
- loc, vty,
+ half = spirv::ConstantOp::create(
+ rewriter, loc, vty,
DenseElementsAttr::get(vty,
rewriter.getFloatAttr(ety, 0.5).getValue()));
} else {
- half = rewriter.create<spirv::ConstantOp>(
- loc, ty, rewriter.getFloatAttr(ety, 0.5));
+ half = spirv::ConstantOp::create(rewriter, loc, ty,
+ rewriter.getFloatAttr(ety, 0.5));
}
- auto abs = rewriter.create<spirv::GLFAbsOp>(loc, operand);
- auto floor = rewriter.create<spirv::GLFloorOp>(loc, abs);
- auto sub = rewriter.create<spirv::FSubOp>(loc, abs, floor);
+ auto abs = spirv::GLFAbsOp::create(rewriter, loc, operand);
+ auto floor = spirv::GLFloorOp::create(rewriter, loc, abs);
+ auto sub = spirv::FSubOp::create(rewriter, loc, abs, floor);
auto greater =
- rewriter.create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half);
- auto select = rewriter.create<spirv::SelectOp>(loc, greater, one, zero);
- auto add = rewriter.create<spirv::FAddOp>(loc, floor, select);
+ spirv::FOrdGreaterThanEqualOp::create(rewriter, loc, sub, half);
+ auto select = spirv::SelectOp::create(rewriter, loc, greater, one, zero);
+ auto add = spirv::FAddOp::create(rewriter, loc, floor, select);
rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand);
return success();
}
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 0b7ffa40ec09d..e882845d9d99a 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -160,8 +160,8 @@ struct ConvertGetGlobal final
if (opTy.getRank() == 0) {
emitc::LValueType lvalueType = emitc::LValueType::get(resultTy);
- emitc::GetGlobalOp globalLValue = rewriter.create<emitc::GetGlobalOp>(
- op.getLoc(), lvalueType, operands.getNameAttr());
+ emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create(
+ rewriter, op.getLoc(), lvalueType, operands.getNameAttr());
emitc::PointerType pointerType = emitc::PointerType::get(resultTy);
rewriter.replaceOpWithNewOp<emitc::ApplyOp>(
op, pointerType, rewriter.getStringAttr("&"), globalLValue);
@@ -191,8 +191,8 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
}
- auto subscript = rewriter.create<emitc::SubscriptOp>(
- op.getLoc(), arrayValue, operands.getIndices());
+ auto subscript = emitc::SubscriptOp::create(
+ rewriter, op.getLoc(), arrayValue, operands.getIndices());
rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
return success();
@@ -211,8 +211,8 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
}
- auto subscript = rewriter.create<emitc::SubscriptOp>(
- op.getLoc(), arrayValue, operands.getIndices());
+ auto subscript = emitc::SubscriptOp::create(
+ rewriter, op.getLoc(), arrayValue, operands.getIndices());
rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
operands.getValue());
return success();
@@ -242,7 +242,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
if (inputs.size() != 1)
return Value();
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
.getResult(0);
};
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 83681b2d5fd87..53a19129103a3 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -87,12 +87,12 @@ getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
/// aligned = bumped - bumped % alignment
static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
Value input, Value alignment) {
- Value one = rewriter.create<LLVM::ConstantOp>(loc, alignment.getType(),
- rewriter.getIndexAttr(1));
- Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
- Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
- Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
- return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
+ Value one = LLVM::ConstantOp::create(rewriter, loc, alignment.getType(),
+ rewriter.getIndexAttr(1));
+ Value bump = LLVM::SubOp::create(rewriter, loc, alignment, one);
+ Value bumped = LLVM::AddOp::create(rewriter, loc, input, bump);
+ Value mod = LLVM::URemOp::create(rewriter, loc, bumped, alignment);
+ return LLVM::SubOp::create(rewriter, loc, bumped, mod);
}
/// Computes the byte size for the MemRef element type.
@@ -123,8 +123,9 @@ static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space");
unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
- allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
- loc, LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
+ allocatedPtr = LLVM::AddrSpaceCastOp::create(
+ rewriter, loc,
+ LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
allocatedPtr);
return allocatedPtr;
}
@@ -168,14 +169,14 @@ class AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
Value alignment = getAlignment(rewriter, loc, op);
if (alignment) {
// Adjust the allocation size to consider alignment.
- sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
+ sizeBytes = LLVM::AddOp::create(rewriter, loc, sizeBytes, alignment);
}
// Allocate the underlying buffer.
Type elementPtrType = this->getElementPtrType(memRefType);
assert(elementPtrType && "could not compute element ptr type");
auto results =
- rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);
+ LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), sizeBytes);
Value allocatedPtr =
castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
@@ -184,11 +185,11 @@ class AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
if (alignment) {
// Compute the aligned pointer.
Value allocatedInt =
- rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
+ LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), allocatedPtr);
Value alignmentInt =
createAligned(rewriter, loc, allocatedInt, alignment);
alignedPtr =
- rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
+ LLVM::IntToPtrOp::create(rewriter, loc, elementPtrType, alignmentInt);
}
// Create the MemRef descriptor.
@@ -268,8 +269,9 @@ class AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
Type elementPtrType = this->getElementPtrType(memRefType);
- auto results = rewriter.create<LLVM::CallOp>(
- loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));
+ auto results =
+ LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(),
+ ValueRange({allocAlignment, sizeBytes}));
Value ptr =
castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
@@ -360,8 +362,9 @@ struct AllocaOpLowering : public ConvertOpToLLVMPattern<memref::AllocaOp> {
auto elementPtrType =
LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);
- auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
- loc, elementPtrType, elementType, size, op.getAlignment().value_or(0));
+ auto allocatedElementPtr =
+ LLVM::AllocaOp::create(rewriter, loc, elementPtrType, elementType, size,
+ op.getAlignment().value_or(0));
// Create the MemRef descriptor.
auto memRefDescriptor = this->createMemRefDescriptor(
@@ -397,7 +400,7 @@ struct AllocaScopeOpLowering
remainingOpsBlock, allocaScopeOp.getResultTypes(),
SmallVector<Location>(allocaScopeOp->getNumResults(),
allocaScopeOp.getLoc()));
- rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
+ LLVM::BrOp::create(rewriter, loc, ValueRange(), remainingOpsBlock);
}
// Inline body region.
@@ -407,8 +410,8 @@ struct AllocaScopeOpLowering
// Save stack and then branch into the body of the region.
rewriter.setInsertionPointToEnd(currentBlock);
- auto stackSaveOp = rewriter.create<LLVM::StackSaveOp>(loc, getPtrType());
- rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
+ auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());
+ LLVM::BrOp::create(rewriter, loc, ValueRange(), beforeBody);
// Replace the alloca_scope return with a branch that jumps out of the body.
// Stack restore before leaving the body region.
@@ -420,7 +423,7 @@ struct AllocaScopeOpLowering
// Insert stack restore before jumping out the body of the region.
rewriter.setInsertionPoint(branchOp);
- rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
+ LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);
// Replace the op with values return from the body region.
rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
@@ -451,11 +454,11 @@ struct AssumeAlignmentOpLowering
// This is more direct than ptrtoint-based checks, is explicitly supported,
// and works with non-integral address spaces.
Value trueCond =
- rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(true));
+ LLVM::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(true));
Value alignmentConst =
createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
- rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
- alignmentConst);
+ LLVM::AssumeOp::create(rewriter, loc, trueCond, LLVM::AssumeAlignTag(), ptr,
+ alignmentConst);
rewriter.replaceOp(op, memref);
return success();
}
@@ -559,18 +562,19 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
// Get pointer to offset field of memref<element_type> descriptor.
auto indexPtrTy =
LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
- Value offsetPtr = rewriter.create<LLVM::GEPOp>(
- loc, indexPtrTy, elementType, underlyingRankedDesc,
- ArrayRef<LLVM::GEPArg>{0, 2});
+ Value offsetPtr =
+ LLVM::GEPOp::create(rewriter, loc, indexPtrTy, elementType,
+ underlyingRankedDesc, ArrayRef<LLVM::GEPArg>{0, 2});
// The size value that we have to extract can be obtained using GEPop with
// `dimOp.index() + 1` index argument.
- Value idxPlusOne = rewriter.create<LLVM::AddOp>(
- loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
+ Value idxPlusOne = LLVM::AddOp::create(
+ rewriter, loc,
+ createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
adaptor.getIndex());
- Value sizePtr = rewriter.create<LLVM::GEPOp>(
- loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
- idxPlusOne);
+ Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy,
+ getTypeConverter()->getIndexType(),
+ offsetPtr, idxPlusOne);
return rewriter
.create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
.getResult();
@@ -674,9 +678,10 @@ struct GenericAtomicRMWOpLowering
auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
auto dataPtr = getStridedElementPtr(
rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
- Value init = rewriter.create<LLVM::LoadOp>(
- loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
- rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
+ Value init = LLVM::LoadOp::create(
+ rewriter, loc, typeConverter->convertType(memRefType.getElementType()),
+ dataPtr);
+ LLVM::BrOp::create(rewriter, loc, init, loopBlock);
// Prepare the body of the loop block.
rewriter.setInsertionPointToStart(loopBlock);
@@ -696,15 +701,16 @@ struct GenericAtomicRMWOpLowering
// Append the cmpxchg op to the end of the loop block.
auto successOrdering = LLVM::AtomicOrdering::acq_rel;
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
- auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
- loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
+ auto cmpxchg =
+ LLVM::AtomicCmpXchgOp::create(rewriter, loc, dataPtr, loopArgument,
+ result, successOrdering, failureOrdering);
// Extract the %new_loaded and %ok values from the pair.
- Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
- Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);
+ Value newLoaded = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 0);
+ Value ok = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 1);
// Conditionally branch to the end or back to the loop depending on %ok.
- rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
- loopBlock, newLoaded);
+ LLVM::CondBrOp::create(rewriter, loc, ok, endBlock, ArrayRef<Value>(),
+ loopBlock, newLoaded);
rewriter.setInsertionPointToEnd(endBlock);
@@ -796,8 +802,8 @@ class GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<memref::GlobalOp> {
if (!isExternal && isUninitialized) {
rewriter.createBlock(&newGlobal.getInitializerRegion());
Value undef[] = {
- rewriter.create<LLVM::UndefOp>(newGlobal.getLoc(), arrayTy)};
- rewriter.create<LLVM::ReturnOp>(newGlobal.getLoc(), undef);
+ LLVM::UndefOp::create(rewriter, newGlobal.getLoc(), arrayTy)};
+ LLVM::ReturnOp::create(rewriter, newGlobal.getLoc(), undef);
}
return success();
}
@@ -842,13 +848,13 @@ struct GetGlobalMemrefOpLowering
Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
auto addressOf =
- rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, op.getName());
+ LLVM::AddressOfOp::create(rewriter, loc, ptrTy, op.getName());
// Get the address of the first element in the array by creating a GEP with
// the address of the GV as the base, and (rank + 1) number of 0 indices.
- auto gep = rewriter.create<LLVM::GEPOp>(
- loc, ptrTy, arrayTy, addressOf,
- SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));
+ auto gep =
+ LLVM::GEPOp::create(rewriter, loc, ptrTy, arrayTy, addressOf,
+ SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));
// We do not expect the memref obtained using `memref.get_global` to be
// ever deallocated. Set the allocated pointer to be known bad value to
@@ -857,7 +863,7 @@ struct GetGlobalMemrefOpLowering
Value deadBeefConst =
createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
auto deadBeefPtr =
- rewriter.create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTy, deadBeefConst);
// Both allocated and aligned pointers are same. We could potentially stash
// a nullptr for the allocated pointer since we do not expect any dealloc.
@@ -1009,8 +1015,8 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
loc, adaptor.getSource(), rewriter);
// rank = ConstantOp srcRank
- auto rankVal = rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(rank));
+ auto rankVal = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(rank));
// poison = PoisonOp
UnrankedMemRefDescriptor memRefDesc =
UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType);
@@ -1029,7 +1035,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
// struct = LoadOp ptr
- auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr);
+ auto loadOp = LLVM::LoadOp::create(rewriter, loc, targetStructType, ptr);
rewriter.replaceOp(memRefCastOp, loadOp.getResult());
} else {
llvm_unreachable("Unsupported unranked memref to unranked memref cast");
@@ -1063,32 +1069,33 @@ class MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
MemRefDescriptor srcDesc(adaptor.getSource());
// Compute number of elements.
- Value numElements = rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(1));
+ Value numElements = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(1));
for (int pos = 0; pos < srcType.getRank(); ++pos) {
auto size = srcDesc.size(rewriter, loc, pos);
- numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
+ numElements = LLVM::MulOp::create(rewriter, loc, numElements, size);
}
// Get element size.
auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
// Compute total.
Value totalSize =
- rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
+ LLVM::MulOp::create(rewriter, loc, numElements, sizeInBytes);
Type elementType = typeConverter->convertType(srcType.getElementType());
Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
Value srcOffset = srcDesc.offset(rewriter, loc);
- Value srcPtr = rewriter.create<LLVM::GEPOp>(
- loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset);
+ Value srcPtr = LLVM::GEPOp::create(rewriter, loc, srcBasePtr.getType(),
+ elementType, srcBasePtr, srcOffset);
MemRefDescriptor targetDesc(adaptor.getTarget());
Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
Value targetOffset = targetDesc.offset(rewriter, loc);
- Value targetPtr = rewriter.create<LLVM::GEPOp>(
- loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset);
- rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
- /*isVolatile=*/false);
+ Value targetPtr =
+ LLVM::GEPOp::create(rewriter, loc, targetBasePtr.getType(), elementType,
+ targetBasePtr, targetOffset);
+ LLVM::MemcpyOp::create(rewriter, loc, targetPtr, srcPtr, totalSize,
+ /*isVolatile=*/false);
rewriter.eraseOp(op);
return success();
@@ -1103,8 +1110,8 @@ class MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
// First make sure we have an unranked memref descriptor representation.
auto makeUnranked = [&, this](Value ranked, MemRefType type) {
- auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
- type.getRank());
+ auto rank = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ type.getRank());
auto *typeConverter = getTypeConverter();
auto ptr =
typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
@@ -1116,7 +1123,7 @@ class MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
};
// Save stack position before promoting descriptors
- auto stackSaveOp = rewriter.create<LLVM::StackSaveOp>(loc, getPtrType());
+ auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());
auto srcMemRefType = dyn_cast<MemRefType>(srcType);
Value unrankedSource =
@@ -1128,13 +1135,13 @@ class MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
: adaptor.getTarget();
// Now promote the unranked descriptors to the stack.
- auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
- rewriter.getIndexAttr(1));
+ auto one = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(1));
auto promote = [&](Value desc) {
auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
auto allocated =
- rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
- rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
+ LLVM::AllocaOp::create(rewriter, loc, ptrType, desc.getType(), one);
+ LLVM::StoreOp::create(rewriter, loc, desc, allocated);
return allocated;
};
@@ -1149,11 +1156,11 @@ class MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
sourcePtr.getType(), symbolTables);
if (failed(copyFn))
return failure();
- rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
- ValueRange{elemSize, sourcePtr, targetPtr});
+ LLVM::CallOp::create(rewriter, loc, copyFn.value(),
+ ValueRange{elemSize, sourcePtr, targetPtr});
// Restore stack used for descriptors
- rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
+ LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);
rewriter.eraseOp(op);
@@ -1204,9 +1211,9 @@ struct MemorySpaceCastOpLowering
MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
descVals);
descVals[0] =
- rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
+ LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[0]);
descVals[1] =
- rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
+ LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[1]);
Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
resultTypeR, descVals);
rewriter.replaceOp(op, result);
@@ -1241,8 +1248,9 @@ struct MemorySpaceCastOpLowering
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
result, resultAddrSpace, sizes);
Value resultUnderlyingSize = sizes.front();
- Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
- loc, getPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
+ Value resultUnderlyingDesc =
+ LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
+ rewriter.getI8Type(), resultUnderlyingSize);
result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);
// Copy pointers, performing address space casts.
@@ -1256,10 +1264,10 @@ struct MemorySpaceCastOpLowering
Value alignedPtr =
sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
sourceUnderlyingDesc, sourceElemPtrType);
- allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
- loc, resultElemPtrType, allocatedPtr);
- alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
- loc, resultElemPtrType, alignedPtr);
+ allocatedPtr = LLVM::AddrSpaceCastOp::create(
+ rewriter, loc, resultElemPtrType, allocatedPtr);
+ alignedPtr = LLVM::AddrSpaceCastOp::create(rewriter, loc,
+ resultElemPtrType, alignedPtr);
result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
resultElemPtrType, allocatedPtr);
@@ -1277,12 +1285,13 @@ struct MemorySpaceCastOpLowering
int64_t bytesToSkip =
2 * llvm::divideCeil(
getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
- Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
- Value copySize = rewriter.create<LLVM::SubOp>(
- loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
- rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
- copySize, /*isVolatile=*/false);
+ Value bytesToSkipConst = LLVM::ConstantOp::create(
+ rewriter, loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
+ Value copySize =
+ LLVM::SubOp::create(rewriter, loc, getIndexType(),
+ resultUnderlyingSize, bytesToSkipConst);
+ LLVM::MemcpyOp::create(rewriter, loc, resultIndexVals, sourceIndexVals,
+ copySize, /*isVolatile=*/false);
rewriter.replaceOp(op, ValueRange{result});
return success();
@@ -1485,7 +1494,7 @@ struct MemRefReshapeOpLowering
} else {
Value shapeOp = reshapeOp.getShape();
Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
- dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
+ dimSize = memref::LoadOp::create(rewriter, loc, shapeOp, index);
Type indexType = getIndexType();
if (dimSize.getType() != indexType)
dimSize = typeConverter->materializeTargetConversion(
@@ -1497,7 +1506,7 @@ struct MemRefReshapeOpLowering
desc.setStride(rewriter, loc, i, stride);
// Prepare the stride value for the next dimension.
- stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
+ stride = LLVM::MulOp::create(rewriter, loc, stride, dimSize);
}
*descriptor = desc;
@@ -1522,8 +1531,9 @@ struct MemRefReshapeOpLowering
SmallVector<Value, 4> sizes;
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
targetDesc, addressSpace, sizes);
- Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
- loc, getPtrType(), IntegerType::get(getContext(), 8), sizes.front());
+ Value underlyingDescPtr = LLVM::AllocaOp::create(
+ rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8),
+ sizes.front());
targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
// Extract pointers and offset from the source memref.
@@ -1554,7 +1564,7 @@ struct MemRefReshapeOpLowering
Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
Value resultRankMinusOne =
- rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
+ LLVM::SubOp::create(rewriter, loc, resultRank, oneIndex);
Block *initBlock = rewriter.getInsertionBlock();
Type indexType = getTypeConverter()->getIndexType();
@@ -1568,15 +1578,15 @@ struct MemRefReshapeOpLowering
rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
rewriter.setInsertionPointToEnd(initBlock);
- rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
- condBlock);
+ LLVM::BrOp::create(rewriter, loc,
+ ValueRange({resultRankMinusOne, oneIndex}), condBlock);
rewriter.setInsertionPointToStart(condBlock);
Value indexArg = condBlock->getArgument(0);
Value strideArg = condBlock->getArgument(1);
Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
- Value pred = rewriter.create<LLVM::ICmpOp>(
- loc, IntegerType::get(rewriter.getContext(), 1),
+ Value pred = LLVM::ICmpOp::create(
+ rewriter, loc, IntegerType::get(rewriter.getContext(), 1),
LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
Block *bodyBlock =
@@ -1585,31 +1595,31 @@ struct MemRefReshapeOpLowering
// Copy size from shape to descriptor.
auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
- Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
- loc, llvmIndexPtrType,
+ Value sizeLoadGep = LLVM::GEPOp::create(
+ rewriter, loc, llvmIndexPtrType,
typeConverter->convertType(shapeMemRefType.getElementType()),
shapeOperandPtr, indexArg);
- Value size = rewriter.create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
+ Value size = LLVM::LoadOp::create(rewriter, loc, indexType, sizeLoadGep);
UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
targetSizesBase, indexArg, size);
// Write stride value and compute next one.
UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
targetStridesBase, indexArg, strideArg);
- Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
+ Value nextStride = LLVM::MulOp::create(rewriter, loc, strideArg, size);
// Decrement loop counter and branch back.
- Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
- rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
- condBlock);
+ Value decrement = LLVM::SubOp::create(rewriter, loc, indexArg, oneIndex);
+ LLVM::BrOp::create(rewriter, loc, ValueRange({decrement, nextStride}),
+ condBlock);
Block *remainder =
rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
// Hook up the cond exit to the remainder.
rewriter.setInsertionPointToEnd(condBlock);
- rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, ValueRange(),
- remainder, ValueRange());
+ LLVM::CondBrOp::create(rewriter, loc, pred, bodyBlock, ValueRange(),
+ remainder, ValueRange());
// Reset position to beginning of new remainder block.
rewriter.setInsertionPointToStart(remainder);
@@ -1738,7 +1748,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
if (nextSize)
return runningStride
- ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
+ ? LLVM::MulOp::create(rewriter, loc, runningStride, nextSize)
: nextSize;
assert(!runningStride);
return createIndexAttrConstant(rewriter, loc, indexType, 1);
@@ -1783,8 +1793,8 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
// Field 2: Copy the actual aligned pointer to payload.
Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
- alignedPtr = rewriter.create<LLVM::GEPOp>(
- loc, alignedPtr.getType(),
+ alignedPtr = LLVM::GEPOp::create(
+ rewriter, loc, alignedPtr.getType(),
typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
adaptor.getByteShift());
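
[Not part of the patch: a minimal before/after sketch of the mechanical rewrite applied throughout this file, reusing the constant-building call from the copy-lowering hunks above; no new names are introduced.]

  // Old form: member template call on the rewriter.
  Value one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                rewriter.getIndexAttr(1));
  // New form: static Op::create with the builder as the first argument.
  Value one = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                       rewriter.getIndexAttr(1));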
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index b866afbce98b0..7a705336bf11c 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -79,7 +79,8 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
assert(indices.size() == 2);
indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
Type t = typeConverter.convertType(op.getComponentPtr().getType());
- return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
+ return spirv::AccessChainOp::create(builder, loc, t, op.getBasePtr(),
+ indices);
}
/// Casts the given `srcBool` into an integer of `dstType`.
@@ -107,8 +108,8 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
value = castBoolToIntN(loc, value, dstType, builder);
} else {
if (valueBits < targetBits) {
- value = builder.create<spirv::UConvertOp>(
- loc, builder.getIntegerType(targetBits), value);
+ value = spirv::UConvertOp::create(
+ builder, loc, builder.getIntegerType(targetBits), value);
}
value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
@@ -372,8 +373,8 @@ AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
std::string varName =
std::string("__workgroup_mem__") +
std::to_string(std::distance(varOps.begin(), varOps.end()));
- varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
- /*initializer=*/nullptr);
+ varOp = spirv::GlobalVariableOp::create(rewriter, loc, spirvType, varName,
+ /*initializer=*/nullptr);
}
// Get pointer to global variable at the current scope.
@@ -572,8 +573,8 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
loadOp, "failed to determine memory requirements");
auto [memoryAccess, alignment] = *memoryRequirements;
- Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
- memoryAccess, alignment);
+ Value loadVal = spirv::LoadOp::create(rewriter, loc, accessChain,
+ memoryAccess, alignment);
if (isBool)
loadVal = castIntNToBool(loc, loadVal, rewriter);
rewriter.replaceOp(loadOp, loadVal);
@@ -601,8 +602,8 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
loadOp, "failed to determine memory requirements");
auto [memoryAccess, alignment] = *memoryRequirements;
- Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
- memoryAccess, alignment);
+ Value spvLoadOp = spirv::LoadOp::create(rewriter, loc, dstType, adjustedPtr,
+ memoryAccess, alignment);
// Shift the bits to the rightmost.
// ____XXXX________ -> ____________XXXX
@@ -770,12 +771,12 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
if (!scope)
return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");
- Value result = rewriter.create<spirv::AtomicAndOp>(
- loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
- clearBitsMask);
- result = rewriter.create<spirv::AtomicOrOp>(
- loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
- storeVal);
+ Value result = spirv::AtomicAndOp::create(
+ rewriter, loc, dstType, adjustedPtr, *scope,
+ spirv::MemorySemantics::AcquireRelease, clearBitsMask);
+ result = spirv::AtomicOrOp::create(
+ rewriter, loc, dstType, adjustedPtr, *scope,
+ spirv::MemorySemantics::AcquireRelease, storeVal);
// The AtomicOrOp has no side effect. Since it is already inserted, we can
// just remove the original StoreOp. Note that rewriter.replaceOp()
@@ -850,12 +851,12 @@ LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
genericPtrType = typeConverter.convertType(intermediateType);
}
if (sourceSc != spirv::StorageClass::Generic) {
- result =
- rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
+ result = spirv::PtrCastToGenericOp::create(rewriter, loc, genericPtrType,
+ result);
}
if (resultSc != spirv::StorageClass::Generic) {
result =
- rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
+ spirv::GenericCastToPtrOp::create(rewriter, loc, resultPtrType, result);
}
rewriter.replaceOp(addrCastOp, result);
return success();
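
[Not part of the patch: a short recap of the pairing visible in the access-chain hunk above — only plain create<...> calls are migrated in this file, while createOrFold<...> keeps the member syntax. Both lines are taken from the hunk itself.]

  // createOrFold stays as a member call on the builder:
  indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
  // while create moves to the static form with the builder leading:
  return spirv::AccessChainOp::create(builder, loc, t, op.getBasePtr(),
                                      indices);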
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index b93128441f2b5..63b1fdabaf407 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -65,7 +65,7 @@ static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
values.emplace_back(*(dyn++));
} else {
TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
- values.emplace_back(b.create<arith::ConstantOp>(loc, type, val));
+ values.emplace_back(arith::ConstantOp::create(b, loc, type, val));
}
}
return values;
@@ -79,9 +79,9 @@ static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
SmallVector<Value> multiIndex(n);
for (int i = n - 1; i >= 0; --i) {
- multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
+ multiIndex[i] = arith::RemSIOp::create(b, loc, linearIndex, dimensions[i]);
if (i > 0)
- linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
+ linearIndex = arith::DivSIOp::create(b, loc, linearIndex, dimensions[i]);
}
return multiIndex;
@@ -91,13 +91,13 @@ static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
ValueRange dimensions) {
- Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0);
- Value stride = b.create<arith::ConstantIndexOp>(loc, 1);
+ Value linearIndex = arith::ConstantIndexOp::create(b, loc, 0);
+ Value stride = arith::ConstantIndexOp::create(b, loc, 1);
for (int i = multiIndex.size() - 1; i >= 0; --i) {
- Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
- linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
- stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
+ Value off = arith::MulIOp::create(b, loc, multiIndex[i], stride);
+ linearIndex = arith::AddIOp::create(b, loc, linearIndex, off);
+ stride = arith::MulIOp::create(b, loc, stride, dimensions[i]);
}
return linearIndex;
@@ -144,11 +144,12 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
auto i64 = rewriter.getI64Type();
std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
maxNAxes};
- Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
+ Value resSplitAxes = tensor::EmptyOp::create(rewriter, loc, shape, i16);
auto attr = IntegerAttr::get(i16, -1);
- Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
- resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
- .getResult(0);
+ Value fillValue = arith::ConstantOp::create(rewriter, loc, i16, attr);
+ resSplitAxes =
+ linalg::FillOp::create(rewriter, loc, fillValue, resSplitAxes)
+ .getResult(0);
// explicitly write values into tensor row by row
std::array<int64_t, 2> strides = {1, 1};
@@ -162,9 +163,10 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
std::array<int64_t, 2> sizes = {1, size};
auto tensorType = RankedTensorType::get({size}, i16);
auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
- auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
- resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
- loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
+ auto vals = arith::ConstantOp::create(rewriter, loc, tensorType, attrs);
+ resSplitAxes = tensor::InsertSliceOp::create(rewriter, loc, vals,
+ resSplitAxes, empty, empty,
+ empty, offs, sizes, strides);
}
// To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
@@ -179,7 +181,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
.create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
i64)
.getResult()
- : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
+ : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes)
.getResult();
// To hold sharded dims offsets, create Tensor with shape {nSplits,
@@ -189,8 +191,8 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
// MeshOp)
Value resOffsets;
if (adaptor.getStaticShardedDimsOffsets().empty()) {
- resOffsets = rewriter.create<tensor::EmptyOp>(
- loc, std::array<int64_t, 2>{0, 0}, i64);
+ resOffsets = tensor::EmptyOp::create(rewriter, loc,
+ std::array<int64_t, 2>{0, 0}, i64);
} else {
SymbolTableCollection symbolTableCollection;
auto meshOp = getMesh(op, symbolTableCollection);
@@ -204,12 +206,12 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
assert(maxSplitSize);
++maxSplitSize; // add one for the total size
- resOffsets = rewriter.create<tensor::EmptyOp>(
- loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
+ resOffsets = tensor::EmptyOp::create(
+ rewriter, loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
+ Value zero = arith::ConstantOp::create(
+ rewriter, loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
resOffsets =
- rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
+ linalg::FillOp::create(rewriter, loc, zero, resOffsets).getResult(0);
SmallVector<Value> offsets =
getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
adaptor.getDynamicShardedDimsOffsets());
@@ -220,11 +222,12 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
++splitSize; // add one for the total size
ArrayRef<Value> values(&offsets[curr], splitSize);
- Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
+ Value vals = tensor::FromElementsOp::create(rewriter, loc, values);
std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
std::array<int64_t, 2> sizes = {1, splitSize};
- resOffsets = rewriter.create<tensor::InsertSliceOp>(
- loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
+ resOffsets = tensor::InsertSliceOp::create(rewriter, loc, vals,
+ resOffsets, empty, empty,
+ empty, offs, sizes, strides);
curr += splitSize;
}
}
@@ -236,10 +239,10 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
return failure();
resSplitAxes =
- rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
+ tensor::CastOp::create(rewriter, loc, resTypes[0], resSplitAxes);
resHaloSizes =
- rewriter.create<tensor::CastOp>(loc, resTypes[1], resHaloSizes);
- resOffsets = rewriter.create<tensor::CastOp>(loc, resTypes[2], resOffsets);
+ tensor::CastOp::create(rewriter, loc, resTypes[1], resHaloSizes);
+ resOffsets = tensor::CastOp::create(rewriter, loc, resTypes[2], resOffsets);
rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
op, TupleType::get(op.getContext(), resTypes),
@@ -269,9 +272,9 @@ struct ConvertProcessMultiIndexOp
SmallVector<Value> dims;
llvm::transform(
meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
- return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
+ return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
});
- Value rank = rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp);
+ Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), meshOp);
auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
// optionally extract subset of mesh axes
@@ -302,7 +305,7 @@ class ConvertProcessLinearIndexOp
Location loc = op.getLoc();
auto ctx = op.getContext();
Value commWorld =
- rewriter.create<mpi::CommWorldOp>(loc, mpi::CommType::get(ctx));
+ mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx));
auto rank =
rewriter
.create<mpi::CommRankOp>(
@@ -341,41 +344,41 @@ struct ConvertNeighborsLinearIndicesOp
SmallVector<Value> dims;
llvm::transform(
meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
- return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
+ return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
});
Value dimSz = dims[axes[0]];
- Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- Value minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1);
- Value atBorder = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sle, orgIdx,
- rewriter.create<arith::ConstantIndexOp>(loc, 0));
- auto down = rewriter.create<scf::IfOp>(
- loc, atBorder,
+ Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ Value minus1 = arith::ConstantIndexOp::create(rewriter, loc, -1);
+ Value atBorder =
+ arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sle, orgIdx,
+ arith::ConstantIndexOp::create(rewriter, loc, 0));
+ auto down = scf::IfOp::create(
+ rewriter, loc, atBorder,
[&](OpBuilder &builder, Location loc) {
- builder.create<scf::YieldOp>(loc, minus1);
+ scf::YieldOp::create(builder, loc, minus1);
},
[&](OpBuilder &builder, Location loc) {
SmallVector<Value> tmp = mIdx;
tmp[axes[0]] =
- rewriter.create<arith::SubIOp>(op.getLoc(), orgIdx, one)
+ arith::SubIOp::create(rewriter, op.getLoc(), orgIdx, one)
.getResult();
- builder.create<scf::YieldOp>(
- loc, multiToLinearIndex(loc, rewriter, tmp, dims));
+ scf::YieldOp::create(builder, loc,
+ multiToLinearIndex(loc, rewriter, tmp, dims));
});
- atBorder = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, orgIdx,
- rewriter.create<arith::SubIOp>(loc, dimSz, one).getResult());
- auto up = rewriter.create<scf::IfOp>(
- loc, atBorder,
+ atBorder = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sge, orgIdx,
+ arith::SubIOp::create(rewriter, loc, dimSz, one).getResult());
+ auto up = scf::IfOp::create(
+ rewriter, loc, atBorder,
[&](OpBuilder &builder, Location loc) {
- builder.create<scf::YieldOp>(loc, minus1);
+ scf::YieldOp::create(builder, loc, minus1);
},
[&](OpBuilder &builder, Location loc) {
SmallVector<Value> tmp = mIdx;
tmp[axes[0]] =
- rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
- builder.create<scf::YieldOp>(
- loc, multiToLinearIndex(loc, rewriter, tmp, dims));
+ arith::AddIOp::create(rewriter, op.getLoc(), orgIdx, one);
+ scf::YieldOp::create(builder, loc,
+ multiToLinearIndex(loc, rewriter, tmp, dims));
});
rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
return success();
@@ -447,8 +450,9 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
rewriter, loc, sharding.getStaticShardedDimsOffsets(),
sharding.getDynamicShardedDimsOffsets(), index);
if (!tmp.empty())
- shardedDimsOffs = rewriter.create<tensor::FromElementsOp>(
- loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp);
+ shardedDimsOffs = tensor::FromElementsOp::create(
+ rewriter, loc, RankedTensorType::get({(int64_t)tmp.size()}, index),
+ tmp);
}
// With static mesh shape the sizes of the split axes are known.
@@ -457,9 +461,9 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
int64_t pos = 0;
SmallVector<Value> shardShape;
Value zero =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(index));
+ arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(index));
Value one =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getOneAttr(index));
+ arith::ConstantOp::create(rewriter, loc, rewriter.getOneAttr(index));
// Iterate over the dimensions of the tensor shape, get their split Axes,
// and compute the sharded shape.
@@ -469,8 +473,8 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
auto axes = splitAxes[i];
// The current dimension might not be sharded.
// Create a value from the static position in shardDimsOffsets.
- Value posVal =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(pos));
+ Value posVal = arith::ConstantOp::create(rewriter, loc,
+ rewriter.getIndexAttr(pos));
// Get the index of the local shard in the mesh axis.
Value idx = multiIdx[axes[0]];
auto numShards =
@@ -482,29 +486,29 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
return op->emitError() << "Only single axis sharding is "
<< "supported for each dimension.";
}
- idx = rewriter.create<arith::AddIOp>(loc, posVal, idx);
+ idx = arith::AddIOp::create(rewriter, loc, posVal, idx);
// Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
Value off =
- rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
- idx = rewriter.create<arith::AddIOp>(loc, idx, one);
+ tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
+ idx = arith::AddIOp::create(rewriter, loc, idx, one);
Value nextOff =
- rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
- Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
+ tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
+ Value sz = arith::SubIOp::create(rewriter, loc, nextOff, off);
shardShape.emplace_back(sz);
} else {
- Value numShardsVal = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexAttr(numShards));
+ Value numShardsVal = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIndexAttr(numShards));
// Compute shard dim size by distributing odd elements to trailing
// shards:
// sz = dim / numShards
// + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
- Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShardsVal);
- Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShardsVal);
- sz1 = rewriter.create<arith::SubIOp>(loc, numShardsVal, sz1);
- auto cond = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, idx, sz1);
- Value odd = rewriter.create<arith::SelectOp>(loc, cond, one, zero);
- sz = rewriter.create<arith::AddIOp>(loc, sz, odd);
+ Value sz = arith::DivSIOp::create(rewriter, loc, dim, numShardsVal);
+ Value sz1 = arith::RemSIOp::create(rewriter, loc, dim, numShardsVal);
+ sz1 = arith::SubIOp::create(rewriter, loc, numShardsVal, sz1);
+ auto cond = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sge, idx, sz1);
+ Value odd = arith::SelectOp::create(rewriter, loc, cond, one, zero);
+ sz = arith::AddIOp::create(rewriter, loc, sz, odd);
shardShape.emplace_back(sz);
}
pos += numShards + 1; // add one for the total size.
@@ -568,7 +572,7 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
if (isa<RankedTensorType>(input.getType())) {
auto memrefType = MemRefType::get(
inputShape, cast<ShapedType>(input.getType()).getElementType());
- input = iBuilder.create<bufferization::ToBufferOp>(memrefType, input);
+ input = bufferization::ToBufferOp::create(iBuilder, memrefType, input);
}
MemRefType inType = cast<MemRefType>(input.getType());
@@ -577,15 +581,15 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
for (auto i = 0; i < inType.getRank(); ++i) {
auto s = inputShape[i];
if (ShapedType::isDynamic(s))
- shape[i] = iBuilder.create<memref::DimOp>(input, s).getResult();
+ shape[i] = memref::DimOp::create(iBuilder, input, s).getResult();
else
shape[i] = iBuilder.getIndexAttr(s);
}
// Allocate buffer and copy input to buffer.
- Value buffer = iBuilder.create<memref::AllocOp>(
- shape, cast<ShapedType>(op.getType()).getElementType());
- iBuilder.create<linalg::CopyOp>(input, buffer);
+ Value buffer = memref::AllocOp::create(
+ iBuilder, shape, cast<ShapedType>(op.getType()).getElementType());
+ linalg::CopyOp::create(iBuilder, input, buffer);
// Get an MPI_Comm_split for the AllReduce operation.
// The color is the linear index of the process in the mesh along the
@@ -594,9 +598,9 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
iBuilder.getIndexType());
SmallVector<Value> myMultiIndex =
- iBuilder.create<ProcessMultiIndexOp>(indexResultTypes, mesh)
+ ProcessMultiIndexOp::create(iBuilder, indexResultTypes, mesh)
.getResult();
- Value zero = iBuilder.create<arith::ConstantIndexOp>(0);
+ Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
SmallVector<Value> multiKey(myMultiIndex.size(), zero);
auto redAxes = adaptor.getMeshAxes();
@@ -607,15 +611,15 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
Value color =
createProcessLinearIndex(mesh, myMultiIndex, redAxes, iBuilder);
- color = iBuilder.create<arith::IndexCastOp>(iBuilder.getI32Type(), color);
+ color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color);
Value key = createProcessLinearIndex(mesh, multiKey, redAxes, iBuilder);
- key = iBuilder.create<arith::IndexCastOp>(iBuilder.getI32Type(), key);
+ key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key);
// Finally split the communicator
auto commType = mpi::CommType::get(op->getContext());
- Value commWorld = iBuilder.create<mpi::CommWorldOp>(commType);
+ Value commWorld = mpi::CommWorldOp::create(iBuilder, commType);
auto comm =
- iBuilder.create<mpi::CommSplitOp>(commType, commWorld, color, key)
+ mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
.getNewcomm();
Value buffer1d = buffer;
@@ -623,19 +627,19 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
if (inType.getRank() > 1) {
ReassociationIndices reassociation(inType.getRank());
std::iota(reassociation.begin(), reassociation.end(), 0);
- buffer1d = iBuilder.create<memref::CollapseShapeOp>(
- buffer, ArrayRef<ReassociationIndices>(reassociation));
+ buffer1d = memref::CollapseShapeOp::create(
+ iBuilder, buffer, ArrayRef<ReassociationIndices>(reassociation));
}
// Create the MPI AllReduce operation.
- iBuilder.create<mpi::AllReduceOp>(
- TypeRange(), buffer1d, buffer1d,
- getMPIReductionOp(adaptor.getReductionAttr()), comm);
+ mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer1d, buffer1d,
+ getMPIReductionOp(adaptor.getReductionAttr()),
+ comm);
// If the destination is a memref, cast it to a tensor
if (isa<RankedTensorType>(op.getType()))
- buffer = iBuilder.create<bufferization::ToTensorOp>(op.getType(), buffer,
- true);
+ buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer,
+ true);
rewriter.replaceOp(op, buffer);
return success();
@@ -676,9 +680,10 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
if (auto value = dyn_cast<Value>(v))
return value;
- return rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexAttr(
- cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
+ return arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getIndexAttr(
+ cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
};
auto dest = adaptor.getDestination();
@@ -689,7 +694,7 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
auto mmemrefType = MemRefType::get(
dstShape, cast<ShapedType>(array.getType()).getElementType());
array =
- rewriter.create<bufferization::ToBufferOp>(loc, mmemrefType, array);
+ bufferization::ToBufferOp::create(rewriter, loc, mmemrefType, array);
}
auto rank = cast<ShapedType>(array.getType()).getRank();
auto opSplitAxes = adaptor.getSplitAxes().getAxes();
@@ -713,7 +718,7 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
for (auto i = 0; i < rank; ++i) {
auto s = dstShape[i];
if (ShapedType::isDynamic(s))
- shape[i] = rewriter.create<memref::DimOp>(loc, array, s).getResult();
+ shape[i] = memref::DimOp::create(rewriter, loc, array, s).getResult();
else
shape[i] = rewriter.getIndexAttr(s);
@@ -723,12 +728,12 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
offsets[i] = haloSizes[currHaloDim * 2];
// prepare shape and offsets of highest dim's halo exchange
- Value _haloSz = rewriter.create<arith::AddIOp>(
- loc, toValue(haloSizes[currHaloDim * 2]),
+ Value _haloSz = arith::AddIOp::create(
+ rewriter, loc, toValue(haloSizes[currHaloDim * 2]),
toValue(haloSizes[currHaloDim * 2 + 1]));
// the halo shape of lower dims excludes the halos
dimSizes[i] =
- rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
+ arith::SubIOp::create(rewriter, loc, toValue(shape[i]), _haloSz)
.getResult();
} else {
dimSizes[i] = shape[i];
@@ -736,14 +741,14 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
}
auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
- auto tag = rewriter.create<arith::ConstantOp>(loc, tagAttr);
+ auto tag = arith::ConstantOp::create(rewriter, loc, tagAttr);
auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
- auto zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
+ auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
rewriter.getIndexType());
auto myMultiIndex =
- rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
+ ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, mesh)
.getResult();
// traverse all split axes from high to low dim
for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
@@ -758,20 +763,22 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
splitAxes)
.getResults();
// MPI operates on i32...
- Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI32Type(), tmp[0]),
- rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI32Type(), tmp[1])};
+ Value neighbourIDs[2] = {
+ arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
+ tmp[0]),
+ arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
+ tmp[1])};
auto lowerRecvOffset = rewriter.getIndexAttr(0);
auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
- auto upperRecvOffset = rewriter.create<arith::SubIOp>(
- loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
- auto upperSendOffset = rewriter.create<arith::SubIOp>(
- loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
+ auto upperRecvOffset =
+ arith::SubIOp::create(rewriter, loc, toValue(shape[dim]),
+ toValue(haloSizes[currHaloDim * 2 + 1]));
+ auto upperSendOffset = arith::SubIOp::create(
+ rewriter, loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
- Value commWorld = rewriter.create<mpi::CommWorldOp>(
- loc, mpi::CommType::get(op->getContext()));
+ Value commWorld = mpi::CommWorldOp::create(
+ rewriter, loc, mpi::CommType::get(op->getContext()));
// Make sure we send/recv in a way that does not lead to a dead-lock.
// The current approach is by far not optimal, this should be at least
@@ -787,37 +794,38 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
// Processes on the mesh borders have only one neighbor
auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
- auto hasFrom = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, from, zero);
- auto hasTo = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, to, zero);
- auto buffer = rewriter.create<memref::AllocOp>(
- loc, dimSizes, cast<ShapedType>(array.getType()).getElementType());
+ auto hasFrom = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sge, from, zero);
+ auto hasTo = arith::CmpIOp::create(rewriter, loc,
+ arith::CmpIPredicate::sge, to, zero);
+ auto buffer = memref::AllocOp::create(
+ rewriter, loc, dimSizes,
+ cast<ShapedType>(array.getType()).getElementType());
// if has neighbor: copy halo data from array to buffer and send
- rewriter.create<scf::IfOp>(
- loc, hasTo, [&](OpBuilder &builder, Location loc) {
+ scf::IfOp::create(
+ rewriter, loc, hasTo, [&](OpBuilder &builder, Location loc) {
offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
: OpFoldResult(upperSendOffset);
- auto subview = builder.create<memref::SubViewOp>(
- loc, array, offsets, dimSizes, strides);
- builder.create<memref::CopyOp>(loc, subview, buffer);
- builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to,
- commWorld);
- builder.create<scf::YieldOp>(loc);
+ auto subview = memref::SubViewOp::create(
+ builder, loc, array, offsets, dimSizes, strides);
+ memref::CopyOp::create(builder, loc, subview, buffer);
+ mpi::SendOp::create(builder, loc, TypeRange{}, buffer, tag, to,
+ commWorld);
+ scf::YieldOp::create(builder, loc);
});
// if has neighbor: receive halo data into buffer and copy to array
- rewriter.create<scf::IfOp>(
- loc, hasFrom, [&](OpBuilder &builder, Location loc) {
+ scf::IfOp::create(
+ rewriter, loc, hasFrom, [&](OpBuilder &builder, Location loc) {
offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
: OpFoldResult(lowerRecvOffset);
- builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from,
- commWorld);
- auto subview = builder.create<memref::SubViewOp>(
- loc, array, offsets, dimSizes, strides);
- builder.create<memref::CopyOp>(loc, buffer, subview);
- builder.create<scf::YieldOp>(loc);
+ mpi::RecvOp::create(builder, loc, TypeRange{}, buffer, tag, from,
+ commWorld);
+ auto subview = memref::SubViewOp::create(
+ builder, loc, array, offsets, dimSizes, strides);
+ memref::CopyOp::create(builder, loc, buffer, subview);
+ scf::YieldOp::create(builder, loc);
});
- rewriter.create<memref::DeallocOp>(loc, buffer);
+ memref::DeallocOp::create(rewriter, loc, buffer);
offsets[dim] = orgOffset;
};
@@ -825,16 +833,17 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
Value haloSz = dyn_cast<Value>(v);
if (!haloSz)
- haloSz = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(
- cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
- auto hasSize = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sgt, haloSz, zero);
- rewriter.create<scf::IfOp>(loc, hasSize,
- [&](OpBuilder &builder, Location loc) {
- genSendRecv(upOrDown > 0);
- builder.create<scf::YieldOp>(loc);
- });
+ haloSz = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getI32IntegerAttr(
+ cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
+ auto hasSize = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sgt, haloSz, zero);
+ scf::IfOp::create(rewriter, loc, hasSize,
+ [&](OpBuilder &builder, Location loc) {
+ genSendRecv(upOrDown > 0);
+ scf::YieldOp::create(builder, loc);
+ });
};
doSendRecv(0);
@@ -852,8 +861,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
rewriter.replaceOp(op, array);
} else {
assert(isa<RankedTensorType>(op.getResult().getType()));
- rewriter.replaceOp(op, rewriter.create<bufferization::ToTensorOp>(
- loc, op.getResult().getType(), array,
+ rewriter.replaceOp(op, bufferization::ToTensorOp::create(
+ rewriter, loc, op.getResult().getType(), array,
/*restrict=*/true, /*writable=*/true));
}
return success();
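
[Not part of the patch: the NVGPU file below mostly builds ops through an ImplicitLocOpBuilder `b`, so the migrated calls pass only the builder and keep the location implicit. A minimal sketch taken from the truncToI32 hunk that follows.]

  // Old: return b.create<LLVM::TruncOp>(b.getI32Type(), value);
  // New: builder first; no explicit Location with ImplicitLocOpBuilder.
  return LLVM::TruncOp::create(b, b.getI32Type(), value);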
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 80b3d85488495..905287e107b0b 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -53,7 +53,7 @@ static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
if (type.getIntOrFloatBitWidth() <= 32)
return value;
- return b.create<LLVM::TruncOp>(b.getI32Type(), value);
+ return LLVM::TruncOp::create(b, b.getI32Type(), value);
}
/// Returns the type for the intrinsic given the vectorResultType of the
@@ -113,8 +113,8 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
Type f32x1Ty = VectorType::get(1, f32Ty);
auto makeConst = [&](int32_t index) -> Value {
- return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
- rewriter.getI32IntegerAttr(index));
+ return LLVM::ConstantOp::create(rewriter, loc, IntegerType::get(ctx, 32),
+ rewriter.getI32IntegerAttr(index));
};
if (arrayType) {
@@ -126,7 +126,7 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
arrayType.getElementType() == f32x1Ty) {
for (unsigned i = 0; i < structType.getBody().size(); i++) {
Value el =
- rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
+ LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i);
el = rewriter.createOrFold<LLVM::BitcastOp>(
loc, arrayType.getElementType(), el);
elements.push_back(el);
@@ -143,24 +143,24 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
Value vec =
- rewriter.create<LLVM::PoisonOp>(loc, arrayType.getElementType());
+ LLVM::PoisonOp::create(rewriter, loc, arrayType.getElementType());
Value x1 =
- rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
- Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
- i * 2 + 1);
- vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
- x1, makeConst(0));
- vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
- x2, makeConst(1));
+ LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i * 2);
+ Value x2 = LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult,
+ i * 2 + 1);
+ vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
+ x1, makeConst(0));
+ vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
+ x2, makeConst(1));
elements.push_back(vec);
}
}
// Create the final vectorized result.
- Value result = rewriter.create<LLVM::PoisonOp>(loc, arrayType);
+ Value result = LLVM::PoisonOp::create(rewriter, loc, arrayType);
for (const auto &el : llvm::enumerate(elements)) {
- result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
- el.index());
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, el.value(),
+ el.index());
}
return result;
}
@@ -187,7 +187,7 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
- Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);
+ Value toUse = LLVM::ExtractValueOp::create(b, operand, i);
// For 4xi8 vectors, the intrinsic expects these to be provided as i32
// scalar types.
@@ -195,7 +195,7 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
arrayTy.getElementType() == i4x8Ty ||
(arrayTy.getElementType() == f32x1Ty &&
operandPtxType == NVVM::MMATypes::tf32)) {
- result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
+ result.push_back(LLVM::BitcastOp::create(b, i32Ty, toUse));
continue;
}
@@ -208,9 +208,9 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
innerArrayTy.getElementType() == f32Ty)) {
for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
idx < innerSize; idx++) {
- result.push_back(b.create<LLVM::ExtractElementOp>(
- toUse,
- b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
+ result.push_back(LLVM::ExtractElementOp::create(
+ b, toUse,
+ LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(idx))));
}
continue;
}
@@ -285,8 +285,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
Value srcPtr =
getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
adaptor.getSrcMemref(), adaptor.getIndices());
- Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
- ldMatrixResultType, srcPtr,
+ Value ldMatrixResult = NVVM::LdMatrixOp::create(
+ b, ldMatrixResultType, srcPtr,
/*num=*/op.getNumTiles(),
/*layout=*/op.getTranspose() ? NVVM::MMALayout::col
: NVVM::MMALayout::row);
@@ -296,13 +296,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
// actual vector type (still of width 32b) and repack them into a result
// struct.
Type finalResultType = typeConverter->convertType(vectorResultType);
- Value result = b.create<LLVM::PoisonOp>(finalResultType);
+ Value result = LLVM::PoisonOp::create(b, finalResultType);
for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
Value i32Register =
- num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
+ num32BitRegs > 1 ? LLVM::ExtractValueOp::create(b, ldMatrixResult, i)
: ldMatrixResult;
- Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
- result = b.create<LLVM::InsertValueOp>(result, casted, i);
+ Value casted = LLVM::BitcastOp::create(b, innerVectorType, i32Register);
+ result = LLVM::InsertValueOp::create(b, result, casted, i);
}
rewriter.replaceOp(op, result);
@@ -375,16 +375,16 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
Type intrinsicResTy = inferIntrinsicResultType(
typeConverter->convertType(op->getResultTypes()[0]));
- Value intrinsicResult = b.create<NVVM::MmaOp>(
- intrinsicResTy, matA, matB, matC,
- /*shape=*/gemmShape,
- /*b1Op=*/std::nullopt,
- /*intOverflow=*/overflow,
- /*multiplicandPtxTypes=*/
- std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
- /*multiplicandLayouts=*/
- std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
- NVVM::MMALayout::col});
+ Value intrinsicResult =
+ NVVM::MmaOp::create(b, intrinsicResTy, matA, matB, matC,
+ /*shape=*/gemmShape,
+ /*b1Op=*/std::nullopt,
+ /*intOverflow=*/overflow,
+ /*multiplicandPtxTypes=*/
+ std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
+ /*multiplicandLayouts=*/
+ std::array<NVVM::MMALayout, 2>{
+ NVVM::MMALayout::row, NVVM::MMALayout::col});
rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
desiredRetTy, intrinsicResult,
rewriter));
@@ -565,15 +565,16 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
llvm::append_range(asmVals, args);
asmVals.push_back(indexData);
- return b.create<LLVM::InlineAsmOp>(
- /*resultTypes=*/intrinsicResultType,
- /*operands=*/asmVals,
- /*asm_string=*/asmStr,
- /*constraints=*/constraintStr,
- /*has_side_effects=*/true,
- /*is_align_stack=*/false, LLVM::TailCallKind::None,
- /*asm_dialect=*/asmDialectAttr,
- /*operand_attrs=*/ArrayAttr());
+ return LLVM::InlineAsmOp::create(b,
+ /*resultTypes=*/intrinsicResultType,
+ /*operands=*/asmVals,
+ /*asm_string=*/asmStr,
+ /*constraints=*/constraintStr,
+ /*has_side_effects=*/true,
+ /*is_align_stack=*/false,
+ LLVM::TailCallKind::None,
+ /*asm_dialect=*/asmDialectAttr,
+ /*operand_attrs=*/ArrayAttr());
}
/// Lowers `nvgpu.mma.sp.sync` to inline assembly.
@@ -631,7 +632,7 @@ struct NVGPUMmaSparseSyncLowering
return op->emitOpError() << "Expected metadata type to be LLVM "
"VectorType of 2 i16 elements";
sparseMetadata =
- b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);
+ LLVM::BitcastOp::create(b, rewriter.getI32Type(), sparseMetadata);
FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
@@ -682,7 +683,7 @@ struct NVGPUAsyncCopyLowering
// Intrinsics takes a global pointer so we need an address space cast.
auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
- scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
+ scrPtr = LLVM::AddrSpaceCastOp::create(b, srcPointerGlobalType, scrPtr);
int64_t dstElements = adaptor.getDstElements().getZExtValue();
int64_t sizeInBytes =
(dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
@@ -697,13 +698,13 @@ struct NVGPUAsyncCopyLowering
// The rest of the DstElements in the destination (shared memory) are
// filled with zeros.
Value c3I32 =
- b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
- Value bitwidth = b.create<LLVM::ConstantOp>(
- b.getI32Type(),
+ LLVM::ConstantOp::create(b, b.getI32Type(), b.getI32IntegerAttr(3));
+ Value bitwidth = LLVM::ConstantOp::create(
+ b, b.getI32Type(),
b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
- Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
- srcBytes = b.create<LLVM::LShrOp>(
- b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
+ Value srcElementsI32 = LLVM::TruncOp::create(b, b.getI32Type(), srcBytes);
+ srcBytes = LLVM::LShrOp::create(
+ b, LLVM::MulOp::create(b, bitwidth, srcElementsI32), c3I32);
}
// Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
// 16 dst bytes.
@@ -712,14 +713,15 @@ struct NVGPUAsyncCopyLowering
? NVVM::LoadCacheModifierKind::CG
: NVVM::LoadCacheModifierKind::CA;
- b.create<NVVM::CpAsyncOp>(
- dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
+ NVVM::CpAsyncOp::create(
+ b, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
srcBytes);
// Drop the result token.
- Value zero = b.create<LLVM::ConstantOp>(
- IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
+ Value zero =
+ LLVM::ConstantOp::create(b, IntegerType::get(op.getContext(), 32),
+ rewriter.getI32IntegerAttr(0));
rewriter.replaceOp(op, zero);
return success();
}
@@ -733,11 +735,11 @@ struct NVGPUAsyncCreateGroupLowering
LogicalResult
matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
+ NVVM::CpAsyncCommitGroupOp::create(rewriter, op.getLoc());
// Drop the result token.
- Value zero = rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), IntegerType::get(op.getContext(), 32),
- rewriter.getI32IntegerAttr(0));
+ Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
+ IntegerType::get(op.getContext(), 32),
+ rewriter.getI32IntegerAttr(0));
rewriter.replaceOp(op, zero);
return success();
}
@@ -753,7 +755,7 @@ struct NVGPUAsyncWaitLowering
ConversionPatternRewriter &rewriter) const override {
// If numGroup is not present pick 0 as a conservative correct value.
int32_t numGroups = adaptor.getNumGroups().value_or(0);
- rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
+ NVVM::CpAsyncWaitGroupOp::create(rewriter, op.getLoc(), numGroups);
rewriter.eraseOp(op);
return success();
}
@@ -771,8 +773,8 @@ struct NVGPUMBarrierCreateLowering
SymbolTable symbolTable(moduleOp);
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(&moduleOp.front());
- auto global = rewriter.create<memref::GlobalOp>(
- funcOp->getLoc(), "__mbarrier",
+ auto global = memref::GlobalOp::create(
+ rewriter, funcOp->getLoc(), "__mbarrier",
/*sym_visibility=*/rewriter.getStringAttr("private"),
/*type=*/barrierType,
/*initial_value=*/ElementsAttr(),
@@ -974,7 +976,7 @@ struct NVGPUMBarrierTryWaitParityLowering
adaptor.getMbarId(), rewriter);
Value ticks = truncToI32(b, adaptor.getTicks());
Value phase =
- b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());
+ LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
if (isMbarrierShared(op.getBarriers().getType())) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
@@ -1063,16 +1065,16 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
auto ti64 = b.getIntegerType(64);
auto makeConst = [&](uint64_t index) -> Value {
- return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
+ return LLVM::ConstantOp::create(b, ti64, b.getI64IntegerAttr(index));
};
auto shiftLeft = [&](Value value, unsigned shift) -> Value {
- return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
+ return LLVM::ShlOp::create(b, ti64, value, makeConst(shift));
};
auto shiftRight = [&](Value value, unsigned shift) -> Value {
- return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
+ return LLVM::LShrOp::create(b, ti64, value, makeConst(shift));
};
auto insertBit = [&](Value desc, Value val, int startBit) {
- return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
+ return LLVM::OrOp::create(b, ti64, desc, shiftLeft(val, startBit));
};
int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
@@ -1086,7 +1088,7 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
Value baseAddr = getStridedElementPtr(
rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
adaptor.getTensor(), {});
- Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
+ Value basePtr = LLVM::PtrToIntOp::create(b, ti64, baseAddr);
// Just use 14 bits for base address
Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
@@ -1118,8 +1120,8 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
};
static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
- return b.create<LLVM::ConstantOp>(b.getIntegerType(64),
- b.getI32IntegerAttr(index));
+ return LLVM::ConstantOp::create(b, b.getIntegerType(64),
+ b.getI32IntegerAttr(index));
}
/// Returns a Value that holds data type enum that is expected by CUDA driver.
@@ -1182,12 +1184,12 @@ struct NVGPUTmaCreateDescriptorOpLowering
auto promotedOperands = getTypeConverter()->promoteOperands(
b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
- Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
- makeI64Const(b, 5));
+ Value boxArrayPtr = LLVM::AllocaOp::create(
+ b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5));
for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
- Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
- boxArrayPtr, makeI64Const(b, index));
- b.create<LLVM::StoreOp>(value, gep);
+ Value gep = LLVM::GEPOp::create(b, llvmPointerType, llvmPointerType,
+ boxArrayPtr, makeI64Const(b, index));
+ LLVM::StoreOp::create(b, value, gep);
}
nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
@@ -1337,7 +1339,7 @@ struct NVGPUWarpgroupMmaOpLowering
/// Basic function to generate Add
Value makeAdd(Value lhs, Value rhs) {
- return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
+ return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
};
/// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
@@ -1430,29 +1432,30 @@ struct NVGPUWarpgroupMmaOpLowering
auto overflow = NVVM::MMAIntOverflowAttr::get(
op->getContext(), NVVM::MMAIntOverflow::wrapped);
- return b.create<NVVM::WgmmaMmaAsyncOp>(
- matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
- itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
+ return NVVM::WgmmaMmaAsyncOp::create(
+ b, matrixC.getType(), matrixC, descriptorA, descriptorB, shape,
+ itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
overflow);
}
/// Generates multiple wgmma instructions to complete the given GEMM shape
Value generateWgmmaGroup() {
Value wgmmaResult =
- b.create<LLVM::PoisonOp>(adaptor.getMatrixC().getType());
+ LLVM::PoisonOp::create(b, adaptor.getMatrixC().getType());
// Perform GEMM
SmallVector<Value> wgmmaResults;
for (int i = 0; i < iterationM; ++i) {
- Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
+ Value matrixC =
+ LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i);
for (int j = 0; j < iterationN; ++j)
for (int k = 0; k < iterationK; ++k)
matrixC = generateWgmma(i, j, k, matrixC);
wgmmaResults.push_back(matrixC);
}
for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
- wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
- wgmmaResult, matrix, idx);
+ wgmmaResult = LLVM::InsertValueOp::create(b, wgmmaResult.getType(),
+ wgmmaResult, matrix, idx);
}
return wgmmaResult;
}
@@ -1486,10 +1489,10 @@ struct NVGPUWarpgroupMmaOpLowering
/// (WgmmaGroupSyncAlignedOp) for group synchronization
/// (WgmmaWaitGroupSyncOp) after the instructions.
Value generateWarpgroupMma() {
- b.create<NVVM::WgmmaFenceAlignedOp>();
+ NVVM::WgmmaFenceAlignedOp::create(b);
Value wgmmaResult = generateWgmmaGroup();
- b.create<NVVM::WgmmaGroupSyncAlignedOp>();
- b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
+ NVVM::WgmmaGroupSyncAlignedOp::create(b);
+ NVVM::WgmmaWaitGroupSyncOp::create(b, op.getWaitGroup());
return wgmmaResult;
}
};
@@ -1557,7 +1560,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering
Type i32 = b.getI32Type();
auto makeConst = [&](int32_t index) -> Value {
- return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
+ return LLVM::ConstantOp::create(b, i32, b.getI32IntegerAttr(index));
};
Value c1 = makeConst(1);
Value c2 = makeConst(2);
@@ -1567,29 +1570,29 @@ struct NVGPUWarpgroupMmaStoreOpLowering
Value warpSize = makeConst(kWarpSize);
auto makeMul = [&](Value lhs, Value rhs) -> Value {
- return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
+ return LLVM::MulOp::create(b, lhs.getType(), lhs, rhs);
};
auto makeAdd = [&](Value lhs, Value rhs) -> Value {
- return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
+ return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
};
auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
TypedValue<::mlir::MemRefType> memref) {
Type it = b.getIndexType();
- Value idx = b.create<arith::IndexCastOp>(it, x);
- Value idy0 = b.create<arith::IndexCastOp>(it, y);
- Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
- Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
- Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
- b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
- b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
+ Value idx = arith::IndexCastOp::create(b, it, x);
+ Value idy0 = arith::IndexCastOp::create(b, it, y);
+ Value idy1 = arith::IndexCastOp::create(b, it, makeAdd(y, c1));
+ Value d0 = LLVM::ExtractValueOp::create(b, wgmmaResult, i);
+ Value d1 = LLVM::ExtractValueOp::create(b, wgmmaResult, i + 1);
+ memref::StoreOp::create(b, d0, memref, ValueRange{idx, idy0});
+ memref::StoreOp::create(b, d1, memref, ValueRange{idx, idy1});
};
- Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
- Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
- Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
- Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
- Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
+ Value tidx = NVVM::ThreadIdXOp::create(b, i32);
+ Value laneId = LLVM::URemOp::create(b, i32, tidx, warpSize);
+ Value warpId = LLVM::UDivOp::create(b, i32, tidx, warpSize);
+ Value lane4Id = LLVM::UDivOp::create(b, i32, laneId, c4);
+ Value lane4modId = LLVM::URemOp::create(b, i32, laneId, c4);
Value tj = makeMul(lane4modId, c2);
Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
@@ -1626,7 +1629,8 @@ struct NVGPUWarpgroupMmaStoreOpLowering
auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType());
for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
auto structType = cast<LLVM::LLVMStructType>(matrixD);
- Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
+ Value innerStructValue =
+ LLVM::ExtractValueOp::create(b, matriDValue, idx);
storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
offset += structType.getBody().size();
}
@@ -1648,23 +1652,23 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
.getBody()
.front();
- Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
- Value packStruct = b.create<LLVM::PoisonOp>(packStructType);
+ Value zero = LLVM::ConstantOp::create(b, elemType, b.getZeroAttr(elemType));
+ Value packStruct = LLVM::PoisonOp::create(b, packStructType);
SmallVector<Value> innerStructs;
// Unpack the structs and set all values to zero
for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
auto structType = cast<LLVM::LLVMStructType>(s);
- Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
+ Value structValue = LLVM::ExtractValueOp::create(b, packStruct, idx);
for (unsigned i = 0; i < structType.getBody().size(); ++i) {
- structValue = b.create<LLVM::InsertValueOp>(
- structType, structValue, zero, ArrayRef<int64_t>({i}));
+ structValue = LLVM::InsertValueOp::create(b, structType, structValue,
+ zero, ArrayRef<int64_t>({i}));
}
innerStructs.push_back(structValue);
}
// Pack the inner structs into a single struct
for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
- packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
- packStruct, matrix, idx);
+ packStruct = LLVM::InsertValueOp::create(b, packStruct.getType(),
+ packStruct, matrix, idx);
}
rewriter.replaceOp(op, packStruct);
return success();
@@ -1681,7 +1685,7 @@ struct NVGPUTmaFenceOpLowering
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
auto i32Ty = b.getI32Type();
Value tensormapSize =
- b.create<LLVM::ConstantOp>(i32Ty, rewriter.getI32IntegerAttr(128));
+ LLVM::ConstantOp::create(b, i32Ty, rewriter.getI32IntegerAttr(128));
auto memscope =
NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS);
@@ -1716,13 +1720,13 @@ struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
VectorType inTy = op.getIn().getType();
// apply rcp.approx.ftz.f on each element in vector.
auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
- Value ret1DVec = b.create<LLVM::PoisonOp>(llvm1DVectorTy);
+ Value ret1DVec = LLVM::PoisonOp::create(b, llvm1DVectorTy);
int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
for (int i = 0; i < numElems; i++) {
- Value idx = b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(i));
- Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx);
- Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem);
- ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx);
+ Value idx = LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(i));
+ Value elem = LLVM::ExtractElementOp::create(b, inVec, idx);
+ Value dst = NVVM::RcpApproxFtzF32Op::create(b, f32Ty, elem);
+ ret1DVec = LLVM::InsertElementOp::create(b, ret1DVec, dst, idx);
}
return ret1DVec;
};
diff --git a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
index 479725aae8afd..f5b3689c88d26 100644
--- a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
+++ b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
@@ -39,8 +39,8 @@ class ExpandIfCondition : public OpRewritePattern<OpTy> {
IntegerAttr constAttr;
if (!matchPattern(op.getIfCond(), m_Constant(&constAttr))) {
- auto ifOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(),
- op.getIfCond(), false);
+ auto ifOp = scf::IfOp::create(rewriter, op.getLoc(), TypeRange(),
+ op.getIfCond(), false);
rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
auto thenBodyBuilder = ifOp.getThenBodyBuilder(rewriter.getListener());
thenBodyBuilder.clone(*op.getOperation());
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 7ac9687c4eeda..021e31a8ecd97 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -95,8 +95,8 @@ struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
}
// Create new operation.
- auto newOp = rewriter.create<T>(op.getLoc(), resTypes, convertedOperands,
- convertedAttrs);
+ auto newOp = T::create(rewriter, op.getLoc(), resTypes, convertedOperands,
+ convertedAttrs);
// Translate regions.
for (auto [originalRegion, convertedRegion] :
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index 7d20109b3db59..b711e33cfc0d6 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -196,7 +196,7 @@ Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region,
// finalize.
if (isa<ExitNode>(node)) {
builder.setInsertionPointToEnd(block);
- builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
+ pdl_interp::FinalizeOp::create(builder, matcherFunc.getLoc());
return block;
}
@@ -272,8 +272,8 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
auto *operationPos = cast<OperationPosition>(pos);
if (operationPos->isOperandDefiningOp())
// Standard (downward) traversal which directly follows the defining op.
- value = builder.create<pdl_interp::GetDefiningOpOp>(
- loc, builder.getType<pdl::OperationType>(), parentVal);
+ value = pdl_interp::GetDefiningOpOp::create(
+ builder, loc, builder.getType<pdl::OperationType>(), parentVal);
else
// A passthrough operation position.
value = parentVal;
@@ -287,23 +287,23 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
// requested to use a representative value (e.g., upward traversal).
if (isa<pdl::RangeType>(parentVal.getType()) &&
usersPos->useRepresentative())
- value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
+ value = pdl_interp::ExtractOp::create(builder, loc, parentVal, 0);
else
value = parentVal;
// The second operation retrieves the users.
- value = builder.create<pdl_interp::GetUsersOp>(loc, value);
+ value = pdl_interp::GetUsersOp::create(builder, loc, value);
break;
}
case Predicates::ForEachPos: {
assert(!failureBlockStack.empty() && "expected valid failure block");
- auto foreach = builder.create<pdl_interp::ForEachOp>(
- loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
+ auto foreach = pdl_interp::ForEachOp::create(
+ builder, loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
value = foreach.getLoopVariable();
// Create the continuation block.
Block *continueBlock = builder.createBlock(&foreach.getRegion());
- builder.create<pdl_interp::ContinueOp>(loc);
+ pdl_interp::ContinueOp::create(builder, loc);
failureBlockStack.push_back(continueBlock);
currentBlock = &foreach.getRegion().front();
@@ -311,62 +311,64 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
}
case Predicates::OperandPos: {
auto *operandPos = cast<OperandPosition>(pos);
- value = builder.create<pdl_interp::GetOperandOp>(
- loc, builder.getType<pdl::ValueType>(), parentVal,
+ value = pdl_interp::GetOperandOp::create(
+ builder, loc, builder.getType<pdl::ValueType>(), parentVal,
operandPos->getOperandNumber());
break;
}
case Predicates::OperandGroupPos: {
auto *operandPos = cast<OperandGroupPosition>(pos);
Type valueTy = builder.getType<pdl::ValueType>();
- value = builder.create<pdl_interp::GetOperandsOp>(
- loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
+ value = pdl_interp::GetOperandsOp::create(
+ builder, loc,
+ operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
parentVal, operandPos->getOperandGroupNumber());
break;
}
case Predicates::AttributePos: {
auto *attrPos = cast<AttributePosition>(pos);
- value = builder.create<pdl_interp::GetAttributeOp>(
- loc, builder.getType<pdl::AttributeType>(), parentVal,
+ value = pdl_interp::GetAttributeOp::create(
+ builder, loc, builder.getType<pdl::AttributeType>(), parentVal,
attrPos->getName().strref());
break;
}
case Predicates::TypePos: {
if (isa<pdl::AttributeType>(parentVal.getType()))
- value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
+ value = pdl_interp::GetAttributeTypeOp::create(builder, loc, parentVal);
else
- value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
+ value = pdl_interp::GetValueTypeOp::create(builder, loc, parentVal);
break;
}
case Predicates::ResultPos: {
auto *resPos = cast<ResultPosition>(pos);
- value = builder.create<pdl_interp::GetResultOp>(
- loc, builder.getType<pdl::ValueType>(), parentVal,
+ value = pdl_interp::GetResultOp::create(
+ builder, loc, builder.getType<pdl::ValueType>(), parentVal,
resPos->getResultNumber());
break;
}
case Predicates::ResultGroupPos: {
auto *resPos = cast<ResultGroupPosition>(pos);
Type valueTy = builder.getType<pdl::ValueType>();
- value = builder.create<pdl_interp::GetResultsOp>(
- loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
+ value = pdl_interp::GetResultsOp::create(
+ builder, loc,
+ resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
parentVal, resPos->getResultGroupNumber());
break;
}
case Predicates::AttributeLiteralPos: {
auto *attrPos = cast<AttributeLiteralPosition>(pos);
- value =
- builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
+ value = pdl_interp::CreateAttributeOp::create(builder, loc,
+ attrPos->getValue());
break;
}
case Predicates::TypeLiteralPos: {
auto *typePos = cast<TypeLiteralPosition>(pos);
Attribute rawTypeAttr = typePos->getValue();
if (TypeAttr typeAttr = dyn_cast<TypeAttr>(rawTypeAttr))
- value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
+ value = pdl_interp::CreateTypeOp::create(builder, loc, typeAttr);
else
- value = builder.create<pdl_interp::CreateTypesOp>(
- loc, cast<ArrayAttr>(rawTypeAttr));
+ value = pdl_interp::CreateTypesOp::create(builder, loc,
+ cast<ArrayAttr>(rawTypeAttr));
break;
}
case Predicates::ConstraintResultPos: {
@@ -413,56 +415,59 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
Predicates::Kind kind = question->getKind();
switch (kind) {
case Predicates::IsNotNullQuestion:
- builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure);
+ pdl_interp::IsNotNullOp::create(builder, loc, val, success, failure);
break;
case Predicates::OperationNameQuestion: {
auto *opNameAnswer = cast<OperationNameAnswer>(answer);
- builder.create<pdl_interp::CheckOperationNameOp>(
- loc, val, opNameAnswer->getValue().getStringRef(), success, failure);
+ pdl_interp::CheckOperationNameOp::create(
+ builder, loc, val, opNameAnswer->getValue().getStringRef(), success,
+ failure);
break;
}
case Predicates::TypeQuestion: {
auto *ans = cast<TypeAnswer>(answer);
if (isa<pdl::RangeType>(val.getType()))
- builder.create<pdl_interp::CheckTypesOp>(
- loc, val, llvm::cast<ArrayAttr>(ans->getValue()), success, failure);
+ pdl_interp::CheckTypesOp::create(builder, loc, val,
+ llvm::cast<ArrayAttr>(ans->getValue()),
+ success, failure);
else
- builder.create<pdl_interp::CheckTypeOp>(
- loc, val, llvm::cast<TypeAttr>(ans->getValue()), success, failure);
+ pdl_interp::CheckTypeOp::create(builder, loc, val,
+ llvm::cast<TypeAttr>(ans->getValue()),
+ success, failure);
break;
}
case Predicates::AttributeQuestion: {
auto *ans = cast<AttributeAnswer>(answer);
- builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
- success, failure);
+ pdl_interp::CheckAttributeOp::create(builder, loc, val, ans->getValue(),
+ success, failure);
break;
}
case Predicates::OperandCountAtLeastQuestion:
case Predicates::OperandCountQuestion:
- builder.create<pdl_interp::CheckOperandCountOp>(
- loc, val, cast<UnsignedAnswer>(answer)->getValue(),
+ pdl_interp::CheckOperandCountOp::create(
+ builder, loc, val, cast<UnsignedAnswer>(answer)->getValue(),
/*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
success, failure);
break;
case Predicates::ResultCountAtLeastQuestion:
case Predicates::ResultCountQuestion:
- builder.create<pdl_interp::CheckResultCountOp>(
- loc, val, cast<UnsignedAnswer>(answer)->getValue(),
+ pdl_interp::CheckResultCountOp::create(
+ builder, loc, val, cast<UnsignedAnswer>(answer)->getValue(),
/*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
success, failure);
break;
case Predicates::EqualToQuestion: {
bool trueAnswer = isa<TrueAnswer>(answer);
- builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(),
- trueAnswer ? success : failure,
- trueAnswer ? failure : success);
+ pdl_interp::AreEqualOp::create(builder, loc, val, args.front(),
+ trueAnswer ? success : failure,
+ trueAnswer ? failure : success);
break;
}
case Predicates::ConstraintQuestion: {
auto *cstQuestion = cast<ConstraintQuestion>(question);
- auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
- loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
- cstQuestion->getIsNegated(), success, failure);
+ auto applyConstraintOp = pdl_interp::ApplyConstraintOp::create(
+ builder, loc, cstQuestion->getResultTypes(), cstQuestion->getName(),
+ args, cstQuestion->getIsNegated(), success, failure);
constraintOpMap.insert({cstQuestion, applyConstraintOp});
break;
@@ -487,7 +492,7 @@ static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
blocks.push_back(it.second);
values.push_back(cast<PredT>(it.first)->getValue());
}
- builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
+ OpT::create(builder, val.getLoc(), val, values, defaultDest, blocks);
}
void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
@@ -536,12 +541,14 @@ void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
switch (kind) {
case Predicates::OperandCountAtLeastQuestion:
- builder.create<pdl_interp::CheckOperandCountOp>(
- loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
+ pdl_interp::CheckOperandCountOp::create(builder, loc, val, ans,
+ /*compareAtLeast=*/true,
+ childBlock, defaultDest);
break;
case Predicates::ResultCountAtLeastQuestion:
- builder.create<pdl_interp::CheckResultCountOp>(
- loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
+ pdl_interp::CheckResultCountOp::create(builder, loc, val, ans,
+ /*compareAtLeast=*/true,
+ childBlock, defaultDest);
break;
default:
llvm_unreachable("Generating invalid AtLeast operation");
@@ -619,8 +626,8 @@ void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
rootKindAttr = builder.getStringAttr(*rootKind);
builder.setInsertionPointToEnd(currentBlock);
- auto matchOp = builder.create<pdl_interp::RecordMatchOp>(
- pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
+ auto matchOp = pdl_interp::RecordMatchOp::create(
+ builder, pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(),
failureBlockStack.back());
@@ -632,8 +639,8 @@ void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
SymbolRefAttr PatternLowering::generateRewriter(
pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
builder.setInsertionPointToEnd(rewriterModule.getBody());
- auto rewriterFunc = builder.create<pdl_interp::FuncOp>(
- pattern.getLoc(), "pdl_generated_rewriter",
+ auto rewriterFunc = pdl_interp::FuncOp::create(
+ builder, pattern.getLoc(), "pdl_generated_rewriter",
builder.getFunctionType({}, {}));
rewriterSymbolTable.insert(rewriterFunc);
@@ -651,18 +658,18 @@ SymbolRefAttr PatternLowering::generateRewriter(
Operation *oldOp = oldValue.getDefiningOp();
if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
if (Attribute value = attrOp.getValueAttr()) {
- return newValue = builder.create<pdl_interp::CreateAttributeOp>(
- attrOp.getLoc(), value);
+ return newValue = pdl_interp::CreateAttributeOp::create(
+ builder, attrOp.getLoc(), value);
}
} else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
if (TypeAttr type = typeOp.getConstantTypeAttr()) {
- return newValue = builder.create<pdl_interp::CreateTypeOp>(
- typeOp.getLoc(), type);
+ return newValue = pdl_interp::CreateTypeOp::create(
+ builder, typeOp.getLoc(), type);
}
} else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
if (ArrayAttr type = typeOp.getConstantTypesAttr()) {
- return newValue = builder.create<pdl_interp::CreateTypesOp>(
- typeOp.getLoc(), typeOp.getType(), type);
+ return newValue = pdl_interp::CreateTypesOp::create(
+ builder, typeOp.getLoc(), typeOp.getType(), type);
}
}
@@ -684,8 +691,9 @@ SymbolRefAttr PatternLowering::generateRewriter(
auto mappedArgs =
llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
args.append(mappedArgs.begin(), mappedArgs.end());
- builder.create<pdl_interp::ApplyRewriteOp>(
- rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args);
+ pdl_interp::ApplyRewriteOp::create(builder, rewriter.getLoc(),
+ /*resultTypes=*/TypeRange(), rewriteName,
+ args);
} else {
// Otherwise this is a dag rewriter defined using PDL operations.
for (Operation &rewriteOp : *rewriter.getBody()) {
@@ -703,7 +711,7 @@ SymbolRefAttr PatternLowering::generateRewriter(
llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
/*results=*/{}));
- builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
+ pdl_interp::FinalizeOp::create(builder, rewriter.getLoc());
return SymbolRefAttr::get(
builder.getContext(),
pdl_interp::PDLInterpDialect::getRewriterModuleName(),
@@ -716,9 +724,9 @@ void PatternLowering::generateRewriter(
SmallVector<Value, 2> arguments;
for (Value argument : rewriteOp.getArgs())
arguments.push_back(mapRewriteValue(argument));
- auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
- rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.getNameAttr(),
- arguments);
+ auto interpOp = pdl_interp::ApplyRewriteOp::create(
+ builder, rewriteOp.getLoc(), rewriteOp.getResultTypes(),
+ rewriteOp.getNameAttr(), arguments);
for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
rewriteValues[std::get<0>(it)] = std::get<1>(it);
}
@@ -726,16 +734,16 @@ void PatternLowering::generateRewriter(
void PatternLowering::generateRewriter(
pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
- Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
- attrOp.getLoc(), attrOp.getValueAttr());
+ Value newAttr = pdl_interp::CreateAttributeOp::create(
+ builder, attrOp.getLoc(), attrOp.getValueAttr());
rewriteValues[attrOp] = newAttr;
}
void PatternLowering::generateRewriter(
pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
- builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
- mapRewriteValue(eraseOp.getOpValue()));
+ pdl_interp::EraseOp::create(builder, eraseOp.getLoc(),
+ mapRewriteValue(eraseOp.getOpValue()));
}
void PatternLowering::generateRewriter(
@@ -756,9 +764,9 @@ void PatternLowering::generateRewriter(
// Create the new operation.
Location loc = operationOp.getLoc();
- Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
- loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands,
- attributes, operationOp.getAttributeValueNames());
+ Value createdOp = pdl_interp::CreateOperationOp::create(
+ builder, loc, *operationOp.getOpName(), types, hasInferredResultTypes,
+ operands, attributes, operationOp.getAttributeValueNames());
rewriteValues[operationOp.getOp()] = createdOp;
// Generate accesses for any results that have their types constrained.
@@ -768,8 +776,8 @@ void PatternLowering::generateRewriter(
if (resultTys.size() == 1 && isa<pdl::RangeType>(resultTys[0].getType())) {
Value &type = rewriteValues[resultTys[0]];
if (!type) {
- auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
- type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
+ auto results = pdl_interp::GetResultsOp::create(builder, loc, createdOp);
+ type = pdl_interp::GetValueTypeOp::create(builder, loc, results);
}
return;
}
@@ -789,12 +797,13 @@ void PatternLowering::generateRewriter(
// groups because the exact index of the result is not statically known.
Value resultVal;
if (seenVariableLength)
- resultVal = builder.create<pdl_interp::GetResultsOp>(
- loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
+ resultVal = pdl_interp::GetResultsOp::create(
+ builder, loc, isVariadic ? valueRangeTy : valueTy, createdOp,
+ it.index());
else
- resultVal = builder.create<pdl_interp::GetResultOp>(
- loc, valueTy, createdOp, it.index());
- type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
+ resultVal = pdl_interp::GetResultOp::create(builder, loc, valueTy,
+ createdOp, it.index());
+ type = pdl_interp::GetValueTypeOp::create(builder, loc, resultVal);
}
}
@@ -804,8 +813,8 @@ void PatternLowering::generateRewriter(
SmallVector<Value, 4> replOperands;
for (Value operand : rangeOp.getArguments())
replOperands.push_back(mapRewriteValue(operand));
- rewriteValues[rangeOp] = builder.create<pdl_interp::CreateRangeOp>(
- rangeOp.getLoc(), rangeOp.getType(), replOperands);
+ rewriteValues[rangeOp] = pdl_interp::CreateRangeOp::create(
+ builder, rangeOp.getLoc(), rangeOp.getType(), replOperands);
}
void PatternLowering::generateRewriter(
@@ -820,8 +829,8 @@ void PatternLowering::generateRewriter(
// Don't use replace if we know the replaced operation has no results.
auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
if (!opOp || !opOp.getTypeValues().empty()) {
- replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
- replOp.getLoc(), mapRewriteValue(replOp)));
+ replOperands.push_back(pdl_interp::GetResultsOp::create(
+ builder, replOp.getLoc(), mapRewriteValue(replOp)));
}
} else {
for (Value operand : replaceOp.getReplValues())
@@ -830,29 +839,29 @@ void PatternLowering::generateRewriter(
// If there are no replacement values, just create an erase instead.
if (replOperands.empty()) {
- builder.create<pdl_interp::EraseOp>(
- replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue()));
+ pdl_interp::EraseOp::create(builder, replaceOp.getLoc(),
+ mapRewriteValue(replaceOp.getOpValue()));
return;
}
- builder.create<pdl_interp::ReplaceOp>(replaceOp.getLoc(),
- mapRewriteValue(replaceOp.getOpValue()),
- replOperands);
+ pdl_interp::ReplaceOp::create(builder, replaceOp.getLoc(),
+ mapRewriteValue(replaceOp.getOpValue()),
+ replOperands);
}
void PatternLowering::generateRewriter(
pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
- rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
- resultOp.getLoc(), builder.getType<pdl::ValueType>(),
+ rewriteValues[resultOp] = pdl_interp::GetResultOp::create(
+ builder, resultOp.getLoc(), builder.getType<pdl::ValueType>(),
mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
}
void PatternLowering::generateRewriter(
pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
- rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
- resultOp.getLoc(), resultOp.getType(),
+ rewriteValues[resultOp] = pdl_interp::GetResultsOp::create(
+ builder, resultOp.getLoc(), resultOp.getType(),
mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
}
@@ -863,7 +872,7 @@ void PatternLowering::generateRewriter(
// type.
if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) {
rewriteValues[typeOp] =
- builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
+ pdl_interp::CreateTypeOp::create(builder, typeOp.getLoc(), typeAttr);
}
}
@@ -873,8 +882,8 @@ void PatternLowering::generateRewriter(
// If the type isn't constant, the users (e.g. OperationOp) will resolve this
// type.
if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) {
- rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
- typeOp.getLoc(), typeOp.getType(), typeAttr);
+ rewriteValues[typeOp] = pdl_interp::CreateTypesOp::create(
+ builder, typeOp.getLoc(), typeOp.getType(), typeAttr);
}
}
@@ -939,10 +948,10 @@ void PatternLowering::generateOperationResultTypeRewriter(
!replacedOp->isBeforeInBlock(op))
continue;
- Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
- replacedOp->getLoc(), mapRewriteValue(replOpVal));
- types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
- replacedOp->getLoc(), replacedOpResults));
+ Value replacedOpResults = pdl_interp::GetResultsOp::create(
+ builder, replacedOp->getLoc(), mapRewriteValue(replOpVal));
+ types.push_back(pdl_interp::GetValueTypeOp::create(
+ builder, replacedOp->getLoc(), replacedOpResults));
return;
}
@@ -985,16 +994,18 @@ void PDLToPDLInterpPass::runOnOperation() {
// Create the main matcher function This function contains all of the match
// related functionality from patterns in the module.
OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
- auto matcherFunc = builder.create<pdl_interp::FuncOp>(
- module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
+ auto matcherFunc = pdl_interp::FuncOp::create(
+ builder, module.getLoc(),
+ pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
builder.getFunctionType(builder.getType<pdl::OperationType>(),
/*results=*/{}),
/*attrs=*/ArrayRef<NamedAttribute>());
// Create a nested module to hold the functions invoked for rewriting the IR
// after a successful match.
- ModuleOp rewriterModule = builder.create<ModuleOp>(
- module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
+ ModuleOp rewriterModule =
+ ModuleOp::create(builder, module.getLoc(),
+ pdl_interp::PDLInterpDialect::getRewriterModuleName());
// Generate the code for the patterns within the module.
PatternLowering generator(matcherFunc, rewriterModule, configMap);
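
For reviewers skimming the change: every hunk above performs the same mechanical rewrite, replacing the builder-centric call `rewriter.create<OpTy>(loc, args...)` with the op-centric form `OpTy::create(rewriter, loc, args...)` (or `OpTy::create(b, args...)` when an ImplicitLocOpBuilder already carries the location). A minimal sketch of the pattern, where `SomeOp` and the lowering itself are placeholders and not code taken from this patch:

// Sketch only: `SomeOp` is a hypothetical source op; the constant lowering
// body is illustrative, not part of this patch.
struct SomeOpLowering : ConvertOpToLLVMPattern<SomeOp> {
  using ConvertOpToLLVMPattern<SomeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SomeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Old style: the rewriter's templated create.
    //   Value zero = rewriter.create<LLVM::ConstantOp>(
    //       loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(0));
    // New style: the op's static create, with the builder as the first argument.
    Value zero = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
                                          rewriter.getI64IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

The explicit-builder form also composes with ImplicitLocOpBuilder, which is why hunks such as `LLVM::ConstantOp::create(b, ti64, b.getI64IntegerAttr(index))` above can omit the Location argument entirely.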