[Mlir-commits] [mlir] 563879b - [NFC] Use ConvertOpToLLVMPattern instead of ConvertToLLVMPattern.
Rahul Joshi
llvmlistbot at llvm.org
Thu Dec 10 09:34:11 PST 2020
Author: Rahul Joshi
Date: 2020-12-10T09:33:43-08:00
New Revision: 563879b6f9465982b422a69a901e3d84e7cb7764
URL: https://github.com/llvm/llvm-project/commit/563879b6f9465982b422a69a901e3d84e7cb7764
DIFF: https://github.com/llvm/llvm-project/commit/563879b6f9465982b422a69a901e3d84e7cb7764.diff
LOG: [NFC] Use ConvertOpToLLVMPattern instead of ConvertToLLVMPattern.
- Use ConvertOpToLLVMPattern to avoid explicit casting; in most cases the
base-class constructor can be reused, saving a few lines of code.
Differential Revision: https://reviews.llvm.org/D92989
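
For illustration, a minimal sketch (not part of the patch) of the pattern shape this change converges on; FooOp, FooOpLowering, and populateFooToLLVMConversionPatterns are hypothetical names standing in for any lowered op:

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"

using namespace mlir;

// Before this patch, a pattern derived from ConvertToLLVMPattern had to pass
// the op name and an MLIRContext* to the base constructor and cast the
// generic Operation* inside matchAndRewrite. With ConvertOpToLLVMPattern the
// base constructor is inherited and the op arrives already typed.
struct FooOpLowering : public ConvertOpToLLVMPattern<FooOp> {
  using ConvertOpToLLVMPattern<FooOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(FooOp fooOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // No cast<FooOp>(op) needed: this pattern only ever matches FooOp.
    // Placeholder rewrite; a real lowering would build LLVM dialect ops here.
    rewriter.replaceOp(fooOp, operands);
    return success();
  }
};

// Populate functions no longer thread an MLIRContext* through; the context is
// recovered from the type converter.
void populateFooToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                         OwningRewritePatternList &patterns) {
  patterns.insert<FooOpLowering>(converter);
}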
Added:
Modified:
mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h
mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h b/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h
index b7c9d0016d65..948c2a4be6f2 100644
--- a/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h
+++ b/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h
@@ -18,8 +18,7 @@ template <typename T> class OperationPass;
/// Populate the given list with patterns that convert from Linalg to LLVM.
void populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter,
- OwningRewritePatternList &patterns,
- MLIRContext *ctx);
+ OwningRewritePatternList &patterns);
/// Create a pass to convert Linalg operations to the LLVMIR dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertLinalgToLLVMPass();
diff --git a/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h b/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h
index ace07ac7223b..4eae84cd0135 100644
--- a/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h
+++ b/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h
@@ -19,8 +19,7 @@ class OperationPass;
class OwningRewritePatternList;
/// Populate the given list with patterns that convert from OpenMP to LLVM.
-void populateOpenMPToLLVMConversionPatterns(MLIRContext *context,
- LLVMTypeConverter &converter,
+void populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
OwningRewritePatternList &patterns);
/// Create a pass to convert OpenMP operations to the LLVMIR dialect.
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index bf41f29749de..7c069c9cd556 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -565,8 +565,8 @@ class ConvertToLLVMPattern : public ConversionPattern {
template <typename SourceOp>
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
- ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
- PatternBenefit benefit = 1)
+ explicit ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
+ PatternBenefit benefit = 1)
: ConvertToLLVMPattern(SourceOp::getOperationName(),
&typeConverter.getContext(), typeConverter,
benefit) {}
diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
index fe06e12c8f21..06a19b057f71 100644
--- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
+++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
@@ -34,8 +34,7 @@ static Type getSrcVectorElementType(OpTy op) {
/// operands as is, preserve attributes.
template <typename SourceOp, typename TargetOp>
static LogicalResult
-matchAndRewriteOneToOne(const ConvertToLLVMPattern &lowering,
- LLVMTypeConverter &typeConverter, Operation *op,
+matchAndRewriteOneToOne(LLVMTypeConverter &typeConverter, Operation *op,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
unsigned numResults = op->getNumResults();
@@ -73,71 +72,61 @@ namespace {
// TODO: Patterns are too verbose due to the fact that we have 1 op (e.g.
// MaskRndScaleOp) and different possible target ops. It would be better to take
// a Functor so that all these conversions become 1-liners.
-struct MaskRndScaleOpPS512Conversion : public ConvertToLLVMPattern {
- explicit MaskRndScaleOpPS512Conversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context,
- typeConverter) {}
+struct MaskRndScaleOpPS512Conversion
+ : public ConvertOpToLLVMPattern<MaskRndScaleOp> {
+ using ConvertOpToLLVMPattern<MaskRndScaleOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(MaskRndScaleOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- if (!getSrcVectorElementType(cast<MaskRndScaleOp>(op)).isF32())
+ if (!getSrcVectorElementType(op).isF32())
return failure();
return matchAndRewriteOneToOne<MaskRndScaleOp,
LLVM::x86_avx512_mask_rndscale_ps_512>(
- *this, *getTypeConverter(), op, operands, rewriter);
+ *getTypeConverter(), op, operands, rewriter);
}
};
-struct MaskRndScaleOpPD512Conversion : public ConvertToLLVMPattern {
- explicit MaskRndScaleOpPD512Conversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context,
- typeConverter) {}
+struct MaskRndScaleOpPD512Conversion
+ : public ConvertOpToLLVMPattern<MaskRndScaleOp> {
+ using ConvertOpToLLVMPattern<MaskRndScaleOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(MaskRndScaleOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- if (!getSrcVectorElementType(cast<MaskRndScaleOp>(op)).isF64())
+ if (!getSrcVectorElementType(op).isF64())
return failure();
return matchAndRewriteOneToOne<MaskRndScaleOp,
LLVM::x86_avx512_mask_rndscale_pd_512>(
- *this, *getTypeConverter(), op, operands, rewriter);
+ *getTypeConverter(), op, operands, rewriter);
}
};
-struct ScaleFOpPS512Conversion : public ConvertToLLVMPattern {
- explicit ScaleFOpPS512Conversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context,
- typeConverter) {}
+struct ScaleFOpPS512Conversion : public ConvertOpToLLVMPattern<MaskScaleFOp> {
+ using ConvertOpToLLVMPattern<MaskScaleFOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(MaskScaleFOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- if (!getSrcVectorElementType(cast<MaskScaleFOp>(op)).isF32())
+ if (!getSrcVectorElementType(op).isF32())
return failure();
return matchAndRewriteOneToOne<MaskScaleFOp,
LLVM::x86_avx512_mask_scalef_ps_512>(
- *this, *getTypeConverter(), op, operands, rewriter);
+ *getTypeConverter(), op, operands, rewriter);
}
};
-struct ScaleFOpPD512Conversion : public ConvertToLLVMPattern {
- explicit ScaleFOpPD512Conversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context,
- typeConverter) {}
+struct ScaleFOpPD512Conversion : public ConvertOpToLLVMPattern<MaskScaleFOp> {
+ using ConvertOpToLLVMPattern<MaskScaleFOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(MaskScaleFOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- if (!getSrcVectorElementType(cast<MaskScaleFOp>(op)).isF64())
+ if (!getSrcVectorElementType(op).isF64())
return failure();
return matchAndRewriteOneToOne<MaskScaleFOp,
LLVM::x86_avx512_mask_scalef_pd_512>(
- *this, *getTypeConverter(), op, operands, rewriter);
+ *getTypeConverter(), op, operands, rewriter);
}
};
} // namespace
@@ -145,11 +134,10 @@ struct ScaleFOpPD512Conversion : public ConvertToLLVMPattern {
/// Populate the given list with patterns that convert from AVX512 to LLVM.
void mlir::populateAVX512ToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
- MLIRContext *ctx = converter.getDialect()->getContext();
// clang-format off
patterns.insert<MaskRndScaleOpPS512Conversion,
MaskRndScaleOpPD512Conversion,
ScaleFOpPS512Conversion,
- ScaleFOpPD512Conversion>(ctx, converter);
+ ScaleFOpPD512Conversion>(converter);
// clang-format on
}
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 69ea393e5df1..bf17200e594f 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -18,17 +18,13 @@
namespace mlir {
template <unsigned AllocaAddrSpace>
-struct GPUFuncOpLowering : ConvertToLLVMPattern {
- explicit GPUFuncOpLowering(LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(gpu::GPUFuncOp::getOperationName(),
- typeConverter.getDialect()->getContext(),
- typeConverter) {}
+struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
+ using ConvertOpToLLVMPattern<gpu::GPUFuncOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
assert(operands.empty() && "func op is not expected to have operands");
- auto gpuFuncOp = cast<gpu::GPUFuncOp>(op);
Location loc = gpuFuncOp.getLoc();
SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
@@ -154,14 +150,11 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
}
};
-struct GPUReturnOpLowering : public ConvertToLLVMPattern {
- GPUReturnOpLowering(LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(gpu::ReturnOp::getOperationName(),
- typeConverter.getDialect()->getContext(),
- typeConverter) {}
+struct GPUReturnOpLowering : public ConvertOpToLLVMPattern<gpu::ReturnOp> {
+ using ConvertOpToLLVMPattern<gpu::ReturnOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(gpu::ReturnOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
return success();
diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
index f4b7cedeb0e1..a51dff51cac4 100644
--- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
@@ -21,7 +21,7 @@ namespace mlir {
// `indexBitwidth`, sign-extend or truncate the resulting value to match the
// bitwidth expected by the consumers of the value.
template <typename Op, typename XOp, typename YOp, typename ZOp>
-struct GPUIndexIntrinsicOpLowering : public ConvertToLLVMPattern {
+struct GPUIndexIntrinsicOpLowering : public ConvertOpToLLVMPattern<Op> {
private:
enum dimension { X = 0, Y = 1, Z = 2, invalid };
unsigned indexBitwidth;
@@ -36,19 +36,17 @@ struct GPUIndexIntrinsicOpLowering : public ConvertToLLVMPattern {
public:
explicit GPUIndexIntrinsicOpLowering(LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(Op::getOperationName(),
- typeConverter.getDialect()->getContext(),
- typeConverter),
+ : ConvertOpToLLVMPattern<Op>(typeConverter),
indexBitwidth(typeConverter.getIndexTypeBitwidth()) {}
// Convert the kernel arguments to an LLVM type, preserve the rest.
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(Op op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
MLIRContext *context = rewriter.getContext();
Value newOp;
- switch (dimensionToIndex(cast<Op>(op))) {
+ switch (dimensionToIndex(op)) {
case X:
newOp = rewriter.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(context));
break;
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index fc743823fd31..9d08aeee1906 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -29,16 +29,15 @@ namespace mlir {
/// will be transformed into
/// llvm.call @__nv_expf(%arg_f32) : (!llvm.float) -> !llvm.float
template <typename SourceOp>
-struct OpToFuncCallLowering : public ConvertToLLVMPattern {
+struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
public:
explicit OpToFuncCallLowering(LLVMTypeConverter &lowering_, StringRef f32Func,
StringRef f64Func)
- : ConvertToLLVMPattern(SourceOp::getOperationName(),
- lowering_.getDialect()->getContext(), lowering_),
- f32Func(f32Func), f64Func(f64Func) {}
+ : ConvertOpToLLVMPattern<SourceOp>(lowering_), f32Func(f32Func),
+ f64Func(f64Func) {}
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
using LLVM::LLVMFuncOp;
using LLVM::LLVMType;
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index a0fe48175636..3e90894e2fe9 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -31,10 +31,8 @@ using namespace mlir;
namespace {
-struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
- explicit GPUShuffleOpLowering(LLVMTypeConverter &lowering_)
- : ConvertToLLVMPattern(gpu::ShuffleOp::getOperationName(),
- lowering_.getDialect()->getContext(), lowering_) {}
+struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
+ using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
/// Lowers a shuffle to the corresponding NVVM op.
///
@@ -53,7 +51,7 @@ struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
/// %shfl_pred = llvm.extractvalue %shfl[1 : index] :
/// !llvm<"{ float, i1 }">
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(gpu::ShuffleOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
gpu::ShuffleOpAdaptor adaptor(operands);
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index b907703995d8..47e8f27ee04a 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -126,19 +126,17 @@ class BaseViewConversionHelper {
};
// RangeOp creates a new range descriptor.
-class RangeOpConversion : public ConvertToLLVMPattern {
+class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> {
public:
- explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
- : ConvertToLLVMPattern(RangeOp::getOperationName(), context, lowering_) {}
+ using ConvertOpToLLVMPattern<RangeOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(RangeOp rangeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto rangeOp = cast<RangeOp>(op);
auto rangeDescriptorTy = convertRangeType(
rangeOp.getType().cast<RangeType>(), *getTypeConverter());
- edsc::ScopedContext context(rewriter, op->getLoc());
+ edsc::ScopedContext context(rewriter, rangeOp->getLoc());
// Fill in an aggregate value of the descriptor.
RangeOpAdaptor adaptor(operands);
@@ -146,7 +144,7 @@ class RangeOpConversion : public ConvertToLLVMPattern {
desc = llvm_insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0));
desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2));
- rewriter.replaceOp(op, desc);
+ rewriter.replaceOp(rangeOp, desc);
return success();
}
};
@@ -154,17 +152,13 @@ class RangeOpConversion : public ConvertToLLVMPattern {
// ReshapeOp creates a new view descriptor of the proper rank.
// For now, the only conversion supported is for target MemRef with static sizes
// and strides.
-class ReshapeOpConversion : public ConvertToLLVMPattern {
+class ReshapeOpConversion : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
- explicit ReshapeOpConversion(MLIRContext *context,
- LLVMTypeConverter &lowering_)
- : ConvertToLLVMPattern(ReshapeOp::getOperationName(), context,
- lowering_) {}
+ using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto reshapeOp = cast<ReshapeOp>(op);
MemRefType dstType = reshapeOp.getResultType();
if (!dstType.hasStaticShape())
@@ -178,7 +172,7 @@ class ReshapeOpConversion : public ConvertToLLVMPattern {
}))
return failure();
- edsc::ScopedContext context(rewriter, op->getLoc());
+ edsc::ScopedContext context(rewriter, reshapeOp->getLoc());
ReshapeOpAdaptor adaptor(operands);
BaseViewConversionHelper baseDesc(adaptor.src());
BaseViewConversionHelper desc(typeConverter->convertType(dstType));
@@ -189,7 +183,7 @@ class ReshapeOpConversion : public ConvertToLLVMPattern {
desc.setConstantSize(en.index(), en.value());
for (auto en : llvm::enumerate(strides))
desc.setConstantStride(en.index(), en.value());
- rewriter.replaceOp(op, {desc});
+ rewriter.replaceOp(reshapeOp, {desc});
return success();
}
};
@@ -200,19 +194,17 @@ class ReshapeOpConversion : public ConvertToLLVMPattern {
/// and stride corresponding to the region of memory within the bounds of
/// the parent view.
/// The linalg.slice op is replaced by the alloca'ed pointer.
-class SliceOpConversion : public ConvertToLLVMPattern {
+class SliceOpConversion : public ConvertOpToLLVMPattern<SliceOp> {
public:
- explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
- : ConvertToLLVMPattern(SliceOp::getOperationName(), context, lowering_) {}
+ using ConvertOpToLLVMPattern<SliceOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(SliceOp sliceOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- edsc::ScopedContext context(rewriter, op->getLoc());
+ edsc::ScopedContext context(rewriter, sliceOp->getLoc());
SliceOpAdaptor adaptor(operands);
BaseViewConversionHelper baseDesc(adaptor.view());
- auto sliceOp = cast<SliceOp>(op);
auto memRefType = sliceOp.getBaseViewType();
auto int64Ty = typeConverter->convertType(rewriter.getIntegerType(64))
.cast<LLVM::LLVMType>();
@@ -248,7 +240,7 @@ class SliceOpConversion : public ConvertToLLVMPattern {
// Corner case, no sizes or strides: early return the descriptor.
if (sliceOp.getShapedType().getRank() == 0)
- return rewriter.replaceOp(op, {desc}), success();
+ return rewriter.replaceOp(sliceOp, {desc}), success();
Value zero = llvm_constant(
int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
@@ -279,20 +271,18 @@ class SliceOpConversion : public ConvertToLLVMPattern {
}
}
- rewriter.replaceOp(op, {desc});
+ rewriter.replaceOp(sliceOp, {desc});
return success();
}
};
// YieldOp produces an LLVM::ReturnOp.
-class YieldOpConversion : public ConvertToLLVMPattern {
+class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
public:
- explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
- : ConvertToLLVMPattern(linalg::YieldOp::getOperationName(), context,
- lowering_) {}
+ using ConvertOpToLLVMPattern<linalg::YieldOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(linalg::YieldOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
return success();
@@ -302,10 +292,9 @@ class YieldOpConversion : public ConvertToLLVMPattern {
/// Populate the given list with patterns that convert from Linalg to LLVM.
void mlir::populateLinalgToLLVMConversionPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
- MLIRContext *ctx) {
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
- YieldOpConversion>(ctx, converter);
+ YieldOpConversion>(converter);
// Populate the type conversions for the linalg types.
converter.addConversion(
@@ -331,7 +320,7 @@ void ConvertLinalgToLLVMPass::runOnOperation() {
populateVectorToSCFConversionPatterns(patterns, &getContext());
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(converter, patterns);
- populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
+ populateLinalgToLLVMConversionPatterns(converter, patterns);
LLVMConversionTarget target(getContext());
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index c589ef69f2c4..e970d82c86df 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -21,34 +21,30 @@ namespace {
/// expected to either be processed by the conversion infrastructure or already
/// contain ops compatible with LLVM dialect types.
template <typename OpType>
-struct RegionOpConversion : public ConvertToLLVMPattern {
- explicit RegionOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(OpType::getOperationName(), context,
- typeConverter) {}
+struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
+ using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(OpType curOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto curOp = cast<OpType>(op);
auto newOp = rewriter.create<OpType>(curOp.getLoc(), TypeRange(), operands,
curOp.getAttrs());
rewriter.inlineRegionBefore(curOp.region(), newOp.region(),
newOp.region().end());
- if (failed(rewriter.convertRegionTypes(&newOp.region(), *typeConverter)))
+ if (failed(rewriter.convertRegionTypes(&newOp.region(),
+ *this->getTypeConverter())))
return failure();
- rewriter.eraseOp(op);
+ rewriter.eraseOp(curOp);
return success();
}
};
} // namespace
void mlir::populateOpenMPToLLVMConversionPatterns(
- MLIRContext *context, LLVMTypeConverter &converter,
- OwningRewritePatternList &patterns) {
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
patterns.insert<RegionOpConversion<omp::ParallelOp>,
- RegionOpConversion<omp::WsLoopOp>>(context, converter);
+ RegionOpConversion<omp::WsLoopOp>>(converter);
}
namespace {
@@ -60,13 +56,12 @@ struct ConvertOpenMPToLLVMPass
void ConvertOpenMPToLLVMPass::runOnOperation() {
auto module = getOperation();
- MLIRContext *context = &getContext();
// Convert to OpenMP operations with LLVM IR dialect
OwningRewritePatternList patterns;
LLVMTypeConverter converter(&getContext());
populateStdToLLVMConversionPatterns(converter, patterns);
- populateOpenMPToLLVMConversionPatterns(context, converter, patterns);
+ populateOpenMPToLLVMConversionPatterns(converter, patterns);
LLVMConversionTarget target(getContext());
target.addDynamicallyLegalOp<omp::ParallelOp, omp::WsLoopOp>(
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 85d3e2bddd66..ebe07366f6ec 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -296,39 +296,33 @@ namespace {
/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
-class VectorMatmulOpConversion : public ConvertToLLVMPattern {
+class VectorMatmulOpConversion
+ : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
- explicit VectorMatmulOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto matmulOp = cast<vector::MatmulOp>(op);
auto adaptor = vector::MatmulOpAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
- op, typeConverter->convertType(matmulOp.res().getType()), adaptor.lhs(),
- adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
- matmulOp.rhs_columns());
+ matmulOp, typeConverter->convertType(matmulOp.res().getType()),
+ adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
+ matmulOp.lhs_columns(), matmulOp.rhs_columns());
return success();
}
};
/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
-class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
+class VectorFlatTransposeOpConversion
+ : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
- explicit VectorFlatTransposeOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
- context, typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto transOp = cast<vector::FlatTransposeOp>(op);
auto adaptor = vector::FlatTransposeOpAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
transOp, typeConverter->convertType(transOp.res().getType()),
@@ -338,18 +332,15 @@ class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
};
/// Conversion pattern for a vector.maskedload.
-class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
+class VectorMaskedLoadOpConversion
+ : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
public:
- explicit VectorMaskedLoadOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
- auto load = cast<vector::MaskedLoadOp>(op);
+ auto loc = load->getLoc();
auto adaptor = vector::MaskedLoadOpAdaptor(operands);
// Resolve alignment.
@@ -371,18 +362,15 @@ class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
};
/// Conversion pattern for a vector.maskedstore.
-class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
+class VectorMaskedStoreOpConversion
+ : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
public:
- explicit VectorMaskedStoreOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
- auto store = cast<vector::MaskedStoreOp>(op);
+ auto loc = store->getLoc();
auto adaptor = vector::MaskedStoreOpAdaptor(operands);
// Resolve alignment.
@@ -404,18 +392,15 @@ class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
};
/// Conversion pattern for a vector.gather.
-class VectorGatherOpConversion : public ConvertToLLVMPattern {
+class VectorGatherOpConversion
+ : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
- explicit VectorGatherOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
- auto gather = cast<vector::GatherOp>(op);
+ auto loc = gather->getLoc();
auto adaptor = vector::GatherOpAdaptor(operands);
// Resolve alignment.
@@ -440,18 +425,15 @@ class VectorGatherOpConversion : public ConvertToLLVMPattern {
};
/// Conversion pattern for a vector.scatter.
-class VectorScatterOpConversion : public ConvertToLLVMPattern {
+class VectorScatterOpConversion
+ : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
- explicit VectorScatterOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
- auto scatter = cast<vector::ScatterOp>(op);
+ auto loc = scatter->getLoc();
auto adaptor = vector::ScatterOpAdaptor(operands);
// Resolve alignment.
@@ -476,18 +458,15 @@ class VectorScatterOpConversion : public ConvertToLLVMPattern {
};
/// Conversion pattern for a vector.expandload.
-class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
+class VectorExpandLoadOpConversion
+ : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
- explicit VectorExpandLoadOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
- auto expand = cast<vector::ExpandLoadOp>(op);
+ auto loc = expand->getLoc();
auto adaptor = vector::ExpandLoadOpAdaptor(operands);
Value ptr;
@@ -497,25 +476,22 @@ class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
auto vType = expand.getResultVectorType();
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
- op, typeConverter->convertType(vType), ptr, adaptor.mask(),
+ expand, typeConverter->convertType(vType), ptr, adaptor.mask(),
adaptor.pass_thru());
return success();
}
};
/// Conversion pattern for a vector.compressstore.
-class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
+class VectorCompressStoreOpConversion
+ : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
- explicit VectorCompressStoreOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
- context, typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
- auto compress = cast<vector::CompressStoreOp>(op);
+ auto loc = compress->getLoc();
auto adaptor = vector::CompressStoreOpAdaptor(operands);
Value ptr;
@@ -524,25 +500,23 @@ class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
return failure();
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
- op, adaptor.value(), ptr, adaptor.mask());
+ compress, adaptor.value(), ptr, adaptor.mask());
return success();
}
};
/// Conversion pattern for all vector reductions.
-class VectorReductionOpConversion : public ConvertToLLVMPattern {
+class VectorReductionOpConversion
+ : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
- explicit VectorReductionOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter,
+ explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
bool reassociateFPRed)
- : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
- typeConverter),
+ : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
reassociateFPReductions(reassociateFPRed) {}
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto reductionOp = cast<vector::ReductionOp>(op);
auto kind = reductionOp.kind();
Type eltType = reductionOp.dest().getType();
Type llvmType = typeConverter->convertType(eltType);
@@ -550,33 +524,33 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
// Integer reductions: add/mul/min/max/and/or/xor.
if (kind == "add")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
- op, llvmType, operands[0]);
+ reductionOp, llvmType, operands[0]);
else if (kind == "mul")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
- op, llvmType, operands[0]);
+ reductionOp, llvmType, operands[0]);
else if (kind == "min" &&
(eltType.isIndex() || eltType.isUnsignedInteger()))
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
- op, llvmType, operands[0]);
+ reductionOp, llvmType, operands[0]);
else if (kind == "min")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
- op, llvmType, operands[0]);
+ reductionOp, llvmType, operands[0]);
else if (kind == "max" &&
(eltType.isIndex() || eltType.isUnsignedInteger()))
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
- op, llvmType, operands[0]);
+ reductionOp, llvmType, operands[0]);
else if (kind == "max")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
- op, llvmType, operands[0]);
+ reductionOp, llvmType, operands[0]);
else if (kind == "and")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
- op, llvmType, operands[0]);
+ reductionOp, llvmType, operands[0]);
else if (kind == "or")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
- op, llvmType, operands[0]);
+ reductionOp, llvmType, operands[0]);
else if (kind == "xor")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
- op, llvmType, operands[0]);
+ reductionOp, llvmType, operands[0]);
else
return failure();
return success();
@@ -590,27 +564,27 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
// Optional accumulator (or zero).
Value acc = operands.size() > 1 ? operands[1]
: rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), llvmType,
+ reductionOp->getLoc(), llvmType,
rewriter.getZeroAttr(eltType));
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
- op, llvmType, acc, operands[0],
+ reductionOp, llvmType, acc, operands[0],
rewriter.getBoolAttr(reassociateFPReductions));
} else if (kind == "mul") {
// Optional accumulator (or one).
Value acc = operands.size() > 1
? operands[1]
: rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), llvmType,
+ reductionOp->getLoc(), llvmType,
rewriter.getFloatAttr(eltType, 1.0));
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
- op, llvmType, acc, operands[0],
+ reductionOp, llvmType, acc, operands[0],
rewriter.getBoolAttr(reassociateFPReductions));
} else if (kind == "min")
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(op, llvmType,
- operands[0]);
+ rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
+ reductionOp, llvmType, operands[0]);
else if (kind == "max")
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(op, llvmType,
- operands[0]);
+ rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
+ reductionOp, llvmType, operands[0]);
else
return failure();
return success();
@@ -621,17 +595,16 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
};
/// Conversion pattern for a vector.create_mask (1-D only).
-class VectorCreateMaskOpConversion : public ConvertToLLVMPattern {
+class VectorCreateMaskOpConversion
+ : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
public:
- explicit VectorCreateMaskOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter,
+ explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
bool enableIndexOpt)
- : ConvertToLLVMPattern(vector::CreateMaskOp::getOperationName(), context,
- typeConverter),
+ : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
enableIndexOptimizations(enableIndexOpt) {}
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto dstType = op->getResult(0).getType().cast<VectorType>();
int64_t rank = dstType.getRank();
@@ -648,19 +621,16 @@ class VectorCreateMaskOpConversion : public ConvertToLLVMPattern {
const bool enableIndexOptimizations;
};
-class VectorShuffleOpConversion : public ConvertToLLVMPattern {
+class VectorShuffleOpConversion
+ : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
- explicit VectorShuffleOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
+ auto loc = shuffleOp->getLoc();
auto adaptor = vector::ShuffleOpAdaptor(operands);
- auto shuffleOp = cast<vector::ShuffleOp>(op);
auto v1Type = shuffleOp.getV1VectorType();
auto v2Type = shuffleOp.getV2VectorType();
auto vectorType = shuffleOp.getVectorType();
@@ -680,9 +650,9 @@ class VectorShuffleOpConversion : public ConvertToLLVMPattern {
// For rank 1, where both operands have *exactly* the same vector type,
// there is direct shuffle support in LLVM. Use it!
if (rank == 1 && v1Type == v2Type) {
- Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
+ Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
- rewriter.replaceOp(op, shuffle);
+ rewriter.replaceOp(shuffleOp, llvmShuffleOp);
return success();
}
@@ -701,23 +671,22 @@ class VectorShuffleOpConversion : public ConvertToLLVMPattern {
insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
llvmType, rank, insPos++);
}
- rewriter.replaceOp(op, insert);
+ rewriter.replaceOp(shuffleOp, insert);
return success();
}
};
-class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
+class VectorExtractElementOpConversion
+ : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
- explicit VectorExtractElementOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
- context, typeConverter) {}
+ using ConvertOpToLLVMPattern<
+ vector::ExtractElementOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::ExtractElementOp extractEltOp,
+ ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::ExtractElementOpAdaptor(operands);
- auto extractEltOp = cast<vector::ExtractElementOp>(op);
auto vectorType = extractEltOp.getVectorType();
auto llvmType = typeConverter->convertType(vectorType.getElementType());
@@ -726,24 +695,21 @@ class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
return failure();
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
- op, llvmType, adaptor.vector(), adaptor.position());
+ extractEltOp, llvmType, adaptor.vector(), adaptor.position());
return success();
}
};
-class VectorExtractOpConversion : public ConvertToLLVMPattern {
+class VectorExtractOpConversion
+ : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
- explicit VectorExtractOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
+ auto loc = extractOp->getLoc();
auto adaptor = vector::ExtractOpAdaptor(operands);
- auto extractOp = cast<vector::ExtractOp>(op);
auto vectorType = extractOp.getVectorType();
auto resultType = extractOp.getResult().getType();
auto llvmResultType = typeConverter->convertType(resultType);
@@ -757,12 +723,12 @@ class VectorExtractOpConversion : public ConvertToLLVMPattern {
if (resultType.isa<VectorType>()) {
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, llvmResultType, adaptor.vector(), positionArrayAttr);
- rewriter.replaceOp(op, extracted);
+ rewriter.replaceOp(extractOp, extracted);
return success();
}
// Potential extraction of 1-D vector from array.
- auto *context = op->getContext();
+ auto *context = extractOp->getContext();
Value extracted = adaptor.vector();
auto positionAttrs = positionArrayAttr.getValue();
if (positionAttrs.size() > 1) {
@@ -780,7 +746,7 @@ class VectorExtractOpConversion : public ConvertToLLVMPattern {
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
extracted =
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
- rewriter.replaceOp(op, extracted);
+ rewriter.replaceOp(extractOp, extracted);
return success();
}
@@ -800,39 +766,32 @@ class VectorExtractOpConversion : public ConvertToLLVMPattern {
/// (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
/// -> !llvm<"<8 x float>">
/// ```
-class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
+class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
- explicit VectorFMAOp1DConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::FMAOpAdaptor(operands);
- vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
VectorType vType = fmaOp.getVectorType();
if (vType.getRank() != 1)
return failure();
- rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
+ rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
adaptor.rhs(), adaptor.acc());
return success();
}
};
-class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
+class VectorInsertElementOpConversion
+ : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
- explicit VectorInsertElementOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
- context, typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::InsertElementOpAdaptor(operands);
- auto insertEltOp = cast<vector::InsertElementOp>(op);
auto vectorType = insertEltOp.getDestVectorType();
auto llvmType = typeConverter->convertType(vectorType);
@@ -841,24 +800,22 @@ class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
return failure();
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
+ insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
+ adaptor.position());
return success();
}
};
-class VectorInsertOpConversion : public ConvertToLLVMPattern {
+class VectorInsertOpConversion
+ : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
- explicit VectorInsertOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
+ auto loc = insertOp->getLoc();
auto adaptor = vector::InsertOpAdaptor(operands);
- auto insertOp = cast<vector::InsertOp>(op);
auto sourceType = insertOp.getSourceType();
auto destVectorType = insertOp.getDestVectorType();
auto llvmResultType = typeConverter->convertType(destVectorType);
@@ -873,12 +830,12 @@ class VectorInsertOpConversion : public ConvertToLLVMPattern {
Value inserted = rewriter.create<LLVM::InsertValueOp>(
loc, llvmResultType, adaptor.dest(), adaptor.source(),
positionArrayAttr);
- rewriter.replaceOp(op, inserted);
+ rewriter.replaceOp(insertOp, inserted);
return success();
}
// Potential extraction of 1-D vector from array.
- auto *context = op->getContext();
+ auto *context = insertOp->getContext();
Value extracted = adaptor.dest();
auto positionAttrs = positionArrayAttr.getValue();
auto position = positionAttrs.back().cast<IntegerAttr>();
@@ -908,7 +865,7 @@ class VectorInsertOpConversion : public ConvertToLLVMPattern {
nMinusOnePositionAttrs);
}
- rewriter.replaceOp(op, inserted);
+ rewriter.replaceOp(insertOp, inserted);
return success();
}
};
@@ -1117,18 +1074,15 @@ computeContiguousStrides(MemRefType memRefType) {
return strides;
}
-class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
+class VectorTypeCastOpConversion
+ : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
- explicit VectorTypeCastOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
- vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
+ auto loc = castOp->getLoc();
MemRefType sourceMemRefType =
castOp.getOperand().getType().cast<MemRefType>();
MemRefType targetMemRefType =
@@ -1195,7 +1149,7 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
desc.setStride(rewriter, loc, index, stride);
}
- rewriter.replaceOp(op, {desc});
+ rewriter.replaceOp(castOp, {desc});
return success();
}
};
@@ -1208,18 +1162,16 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
/// 4. Create a mask where offsetVector is compared against memref upper bound.
/// 5. Rewrite op as a masked read or write.
template <typename ConcreteOp>
-class VectorTransferConversion : public ConvertToLLVMPattern {
+class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
public:
- explicit VectorTransferConversion(MLIRContext *context,
- LLVMTypeConverter &typeConv,
+ explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
bool enableIndexOpt)
- : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, typeConv),
+ : ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
enableIndexOptimizations(enableIndexOpt) {}
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto xferOp = cast<ConcreteOp>(op);
auto adaptor = getTransferOpAdapter(xferOp, operands);
if (xferOp.getVectorType().getRank() > 1 ||
@@ -1228,16 +1180,18 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
if (xferOp.permutation_map() !=
AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
xferOp.getVectorType().getRank(),
- op->getContext()))
+ xferOp->getContext()))
return failure();
// Only contiguous source tensors supported atm.
auto strides = computeContiguousStrides(xferOp.getMemRefType());
if (!strides)
return failure();
- auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); };
+ auto toLLVMTy = [&](Type t) {
+ return this->getTypeConverter()->convertType(t);
+ };
- Location loc = op->getLoc();
+ Location loc = xferOp->getLoc();
MemRefType memRefType = xferOp.getMemRefType();
if (auto memrefVectorElementType =
@@ -1267,8 +1221,8 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
// addrspacecast shall be used when source/dst memrefs are not on
// address space 0.
// TODO: support alignment when possible.
- Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
- adaptor.indices(), rewriter);
+ Value dataPtr = this->getStridedElementPtr(
+ loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter);
auto vecTy =
toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
Value vectorDataPtr;
@@ -1280,8 +1234,9 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
loc, vecTy.getPointerTo(), dataPtr);
if (!xferOp.isMaskedDim(0))
- return replaceTransferOpWithLoadOrStore(
- rewriter, *getTypeConverter(), loc, xferOp, operands, vectorDataPtr);
+ return replaceTransferOpWithLoadOrStore(rewriter,
+ *this->getTypeConverter(), loc,
+ xferOp, operands, vectorDataPtr);
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
@@ -1294,11 +1249,11 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
Value off = xferOp.indices()[lastIndex];
Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
- Value mask = buildVectorComparison(rewriter, op, enableIndexOptimizations,
- vecWidth, dim, &off);
+ Value mask = buildVectorComparison(
+ rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
// 5. Rewrite as a masked read / write.
- return replaceTransferOpWithMasked(rewriter, *getTypeConverter(), loc,
+ return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
xferOp, operands, vectorDataPtr, mask);
}
@@ -1306,12 +1261,9 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
const bool enableIndexOptimizations;
};
-class VectorPrintOpConversion : public ConvertToLLVMPattern {
+class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
- explicit VectorPrintOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
- typeConverter) {}
+ using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
// Proof-of-concept lowering implementation that relies on a small
// runtime support library, which only needs to provide a few
@@ -1326,9 +1278,8 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
// TODO: rely solely on libc in future? something else?
//
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto printOp = cast<vector::PrintOp>(op);
auto adaptor = vector::PrintOpAdaptor(operands);
Type printType = printOp.getPrintType();
@@ -1341,11 +1292,11 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
Type eltType = vectorType ? vectorType.getElementType() : printType;
Operation *printer;
if (eltType.isF32()) {
- printer = getPrintFloat(op);
+ printer = getPrintFloat(printOp);
} else if (eltType.isF64()) {
- printer = getPrintDouble(op);
+ printer = getPrintDouble(printOp);
} else if (eltType.isIndex()) {
- printer = getPrintU64(op);
+ printer = getPrintU64(printOp);
} else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
// Integers need a zero or sign extension on the operand
// (depending on the source type) as well as a signed or
@@ -1355,7 +1306,7 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
if (width <= 64) {
if (width < 64)
conversion = PrintConversion::ZeroExt64;
- printer = getPrintU64(op);
+ printer = getPrintU64(printOp);
} else {
return failure();
}
@@ -1368,7 +1319,7 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
conversion = PrintConversion::ZeroExt64;
else if (width < 64)
conversion = PrintConversion::SignExt64;
- printer = getPrintI64(op);
+ printer = getPrintI64(printOp);
} else {
return failure();
}
@@ -1379,10 +1330,10 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
// Unroll vector into elementary print calls.
int64_t rank = vectorType ? vectorType.getRank() : 0;
- emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank,
+ emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
conversion);
- emitCall(rewriter, op->getLoc(), getPrintNewline(op));
- rewriter.eraseOp(op);
+ emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp));
+ rewriter.eraseOp(printOp);
return success();
}
@@ -1560,11 +1511,11 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorInsertStridedSliceOpSameRankRewritePattern,
VectorExtractStridedSliceOpConversion>(ctx);
patterns.insert<VectorReductionOpConversion>(
- ctx, converter, reassociateFPReductions);
+ converter, reassociateFPReductions);
patterns.insert<VectorCreateMaskOpConversion,
VectorTransferConversion<TransferReadOp>,
VectorTransferConversion<TransferWriteOp>>(
- ctx, converter, enableIndexOptimizations);
+ converter, enableIndexOptimizations);
patterns
.insert<VectorShuffleOpConversion,
VectorExtractElementOpConversion,
@@ -1579,13 +1530,12 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorGatherOpConversion,
VectorScatterOpConversion,
VectorExpandLoadOpConversion,
- VectorCompressStoreOpConversion>(ctx, converter);
+ VectorCompressStoreOpConversion>(converter);
// clang-format on
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
- MLIRContext *ctx = converter.getDialect()->getContext();
- patterns.insert<VectorMatmulOpConversion>(ctx, converter);
- patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
+ patterns.insert<VectorMatmulOpConversion>(converter);
+ patterns.insert<VectorFlatTransposeOpConversion>(converter);
}
diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index 61f094746a0a..e5474abfd3e3 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -55,17 +55,13 @@ namespace {
/// types. For unsupported cases, they will fall back to the vector to
/// llvm conversion pattern.
template <typename ConcreteOp>
-class VectorTransferConversion : public ConvertToLLVMPattern {
+class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
public:
- explicit VectorTransferConversion(MLIRContext *context,
- LLVMTypeConverter &typeConv)
- : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
- typeConv) {}
+ using ConvertOpToLLVMPattern<ConcreteOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto xferOp = cast<ConcreteOp>(op);
typename ConcreteOp::Adaptor adaptor(operands);
if (xferOp.getVectorType().getRank() > 1 ||
@@ -79,11 +75,13 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
if (!xferOp.isMaskedDim(0))
return failure();
- auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); };
+ auto toLLVMTy = [&](Type t) {
+ return this->getTypeConverter()->convertType(t);
+ };
LLVM::LLVMType vecTy =
toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
unsigned vecWidth = vecTy.getVectorNumElements();
- Location loc = op->getLoc();
+ Location loc = xferOp->getLoc();
// The backend's result-vector scalarization has trouble scalarizing a
// <1 x ty> result, so exclude the x1 width from this lowering.
@@ -102,8 +100,8 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
// Note that the dataPtr starts at the offset address specified by
// indices, so no need to calculate offset size in bytes again in
// the MUBUF instruction.
- Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
- adaptor.indices(), rewriter);
+ Value dataPtr = this->getStridedElementPtr(
+ loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter);
// 1. Create and fill a <4 x i32> dwordConfig with:
// 1st two elements holding the address of dataPtr.
@@ -126,7 +124,7 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
constConfig);
Value dataPtrAsI64 = rewriter.create<LLVM::PtrToIntOp>(
loc, toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), dataPtr);
- Value zero = createIndexConstant(rewriter, loc, 0);
+ Value zero = this->createIndexConstant(rewriter, loc, 0);
Value dwordConfig = rewriter.create<LLVM::InsertElementOp>(
loc,
LLVM::LLVMType::getVectorTy(
@@ -143,7 +141,7 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
loc, toLLVMTy(i32Ty),
rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0));
return replaceTransferOpWithMubuf(
- rewriter, operands, *getTypeConverter(), loc, xferOp, vecTy,
+ rewriter, operands, *this->getTypeConverter(), loc, xferOp, vecTy,
dwordConfig, int32Zero, int32Zero, int1False, int1False);
}
};
@@ -151,9 +149,8 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
void mlir::populateVectorToROCDLConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
- MLIRContext *ctx = converter.getDialect()->getContext();
patterns.insert<VectorTransferConversion<TransferReadOp>,
- VectorTransferConversion<TransferWriteOp>>(ctx, converter);
+ VectorTransferConversion<TransferWriteOp>>(converter);
}
namespace {
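
Patterns that carry extra configuration (for example the reassociateFPReductions and enableIndexOptimizations flags above) still spell out a constructor, but now forward only the type converter to the base. A hedged sketch under that assumption, with MyOp, SomeOp, and enableFastPath as hypothetical names:

template <typename MyOp>
struct MyConfigurableOpLowering : public ConvertOpToLLVMPattern<MyOp> {
  MyConfigurableOpLowering(LLVMTypeConverter &typeConverter,
                           bool enableFastPath)
      : ConvertOpToLLVMPattern<MyOp>(typeConverter),
        enableFastPath(enableFastPath) {}

  LogicalResult
  matchAndRewrite(MyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Inside a class template, members of the dependent base need an explicit
    // this->, e.g. this->getTypeConverter(), as in the VectorTransferConversion
    // changes above.
    if (!this->getTypeConverter()->convertType(op->getResult(0).getType()))
      return failure();
    // enableFastPath would select between lowering strategies here; this is a
    // placeholder rewrite for the sketch.
    rewriter.replaceOp(op, operands);
    return success();
  }

private:
  const bool enableFastPath;
};

// Registration passes the extra flag after the converter, e.g.
//   patterns.insert<MyConfigurableOpLowering<SomeOp>>(
//       converter, /*enableFastPath=*/true);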