[Mlir-commits] [mlir] 04ed07b - [mlir] StandardToLLVM: clean up conversion patterns for vector operations
Alex Zinenko
llvmlistbot at llvm.org
Thu Mar 26 10:24:39 PDT 2020
Author: Alex Zinenko
Date: 2020-03-26T18:24:10+01:00
New Revision: 04ed07bc174149d61c8a4ed131f0838578bdcaa5
URL: https://github.com/llvm/llvm-project/commit/04ed07bc174149d61c8a4ed131f0838578bdcaa5
DIFF: https://github.com/llvm/llvm-project/commit/04ed07bc174149d61c8a4ed131f0838578bdcaa5.diff
LOG: [mlir] StandardToLLVM: clean up conversion patterns for vector operations
Summary:
Provide a public VectorConvertToLLVMPattern utility class to implement
conversions with automatic unrolling of operations on multidimensional vectors
to lists of operations on single-dimensional vectors when lowering to the LLVM
dialect. Drop the template-based check on the number of operands since the
actual implementation no longer depends on the number of operands. This check
only creates spurious concepts (UnaryOpLLVMOpLowering, BinaryOpLLVMOpLowering,
etc.).
Differential Revision: https://reviews.llvm.org/D76865
Added:
Modified:
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 95da9805606b..d2c7d9fb2abd 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -416,6 +416,11 @@ LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
ValueRange operands,
LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter);
+
+LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
+ ValueRange operands,
+ LLVMTypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter);
} // namespace detail
} // namespace LLVM
@@ -441,6 +446,29 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
}
};
+/// Basic lowering implementation for rewriting from Ops to LLVM Dialect Ops
+/// with one result. This supports higher-dimensional vector types.
+template <typename SourceOp, typename TargetOp>
+class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
+public:
+ using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
+ using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ static_assert(
+ std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
+ "expected single result op");
+ static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
+ SourceOp>::value,
+ "expected same operands and result type");
+ return LLVM::detail::vectorOneToOneRewrite(op, TargetOp::getOperationName(),
+ operands, this->typeConverter,
+ rewriter);
+ }
+};
+
/// Derived class that automatically populates legalization information for
/// different LLVM ops.
class LLVMConversionTarget : public ConversionTarget {
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 474a4f08b9f6..8bc27ab3340e 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1148,9 +1148,10 @@ template <typename SourceOp, unsigned OpCount>
void ValidateOpCount() {
OpCountValidator<SourceOp, OpCount>();
}
+} // namespace
-static LogicalResult HandleMultidimensionalVectors(
- Operation *op, ArrayRef<Value> operands, LLVMTypeConverter &typeConverter,
+static LogicalResult handleMultidimensionalVectors(
+ Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
std::function<Value(LLVM::LLVMType, ValueRange)> createOperand,
ConversionPatternRewriter &rewriter) {
auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
@@ -1179,139 +1180,125 @@ static LogicalResult HandleMultidimensionalVectors(
return success();
}
-// Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect
-// Ops for N-ary ops with one result. This supports higher-dimensional vector
-// types.
-template <typename SourceOp, typename TargetOp, unsigned OpCount>
-struct NaryOpLLVMOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
- using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
- using Super = NaryOpLLVMOpLowering<SourceOp, TargetOp, OpCount>;
-
- // Convert the type of the result to an LLVM type, pass operands as is,
- // preserve attributes.
- LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- ValidateOpCount<SourceOp, OpCount>();
- static_assert(
- std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
- "expected single result op");
- static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
- SourceOp>::value,
- "expected same operands and result type");
-
- // Cannot convert ops if their operands are not of LLVM type.
- for (Value operand : operands) {
- if (!operand || !operand.getType().isa<LLVM::LLVMType>())
- return failure();
- }
+LogicalResult LLVM::detail::vectorOneToOneRewrite(
+ Operation *op, StringRef targetOp, ValueRange operands,
+ LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
+ assert(!operands.empty());
- auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
+ // Cannot convert ops if their operands are not of LLVM type.
+ if (!llvm::all_of(operands.getTypes(),
+ [](Type t) { return t.isa<LLVM::LLVMType>(); }))
+ return failure();
- if (!llvmArrayTy.isArrayTy()) {
- auto newOp = rewriter.create<TargetOp>(
- op->getLoc(), operands[0].getType(), operands, op->getAttrs());
- rewriter.replaceOp(op, newOp.getResult());
- return success();
- }
+ auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
+ if (!llvmArrayTy.isArrayTy())
+ return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
- if (succeeded(HandleMultidimensionalVectors(
- op, operands, this->typeConverter,
- [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
- return rewriter.create<TargetOp>(op->getLoc(), llvmVectorTy,
- operands, op->getAttrs());
- },
- rewriter)))
- return success();
- return failure();
- }
-};
+ auto callback = [op, targetOp, &rewriter](LLVM::LLVMType llvmVectorTy,
+ ValueRange operands) {
+ OperationState state(op->getLoc(), targetOp);
+ state.addTypes(llvmVectorTy);
+ state.addOperands(operands);
+ state.addAttributes(op->getAttrs());
+ return rewriter.createOperation(state)->getResult(0);
+ };
-template <typename SourceOp, typename TargetOp>
-using UnaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 1>;
-template <typename SourceOp, typename TargetOp>
-using BinaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 2>;
+ return handleMultidimensionalVectors(op, operands, typeConverter, callback,
+ rewriter);
+}
+namespace {
// Specific lowerings.
// FIXME: this should be tablegen'ed.
-struct AbsFOpLowering : public UnaryOpLLVMOpLowering<AbsFOp, LLVM::FAbsOp> {
+struct AbsFOpLowering
+ : public VectorConvertToLLVMPattern<AbsFOp, LLVM::FAbsOp> {
using Super::Super;
};
-struct CeilFOpLowering : public UnaryOpLLVMOpLowering<CeilFOp, LLVM::FCeilOp> {
+struct CeilFOpLowering
+ : public VectorConvertToLLVMPattern<CeilFOp, LLVM::FCeilOp> {
using Super::Super;
};
-struct CosOpLowering : public UnaryOpLLVMOpLowering<CosOp, LLVM::CosOp> {
+struct CosOpLowering : public VectorConvertToLLVMPattern<CosOp, LLVM::CosOp> {
using Super::Super;
};
-struct ExpOpLowering : public UnaryOpLLVMOpLowering<ExpOp, LLVM::ExpOp> {
+struct ExpOpLowering : public VectorConvertToLLVMPattern<ExpOp, LLVM::ExpOp> {
using Super::Super;
};
-struct LogOpLowering : public UnaryOpLLVMOpLowering<LogOp, LLVM::LogOp> {
+struct LogOpLowering : public VectorConvertToLLVMPattern<LogOp, LLVM::LogOp> {
using Super::Super;
};
-struct Log10OpLowering : public UnaryOpLLVMOpLowering<Log10Op, LLVM::Log10Op> {
+struct Log10OpLowering
+ : public VectorConvertToLLVMPattern<Log10Op, LLVM::Log10Op> {
using Super::Super;
};
-struct Log2OpLowering : public UnaryOpLLVMOpLowering<Log2Op, LLVM::Log2Op> {
+struct Log2OpLowering
+ : public VectorConvertToLLVMPattern<Log2Op, LLVM::Log2Op> {
using Super::Super;
};
-struct NegFOpLowering : public UnaryOpLLVMOpLowering<NegFOp, LLVM::FNegOp> {
+struct NegFOpLowering
+ : public VectorConvertToLLVMPattern<NegFOp, LLVM::FNegOp> {
using Super::Super;
};
-struct AddIOpLowering : public BinaryOpLLVMOpLowering<AddIOp, LLVM::AddOp> {
+struct AddIOpLowering : public VectorConvertToLLVMPattern<AddIOp, LLVM::AddOp> {
using Super::Super;
};
-struct SubIOpLowering : public BinaryOpLLVMOpLowering<SubIOp, LLVM::SubOp> {
+struct SubIOpLowering : public VectorConvertToLLVMPattern<SubIOp, LLVM::SubOp> {
using Super::Super;
};
-struct MulIOpLowering : public BinaryOpLLVMOpLowering<MulIOp, LLVM::MulOp> {
+struct MulIOpLowering : public VectorConvertToLLVMPattern<MulIOp, LLVM::MulOp> {
using Super::Super;
};
struct SignedDivIOpLowering
- : public BinaryOpLLVMOpLowering<SignedDivIOp, LLVM::SDivOp> {
+ : public VectorConvertToLLVMPattern<SignedDivIOp, LLVM::SDivOp> {
using Super::Super;
};
-struct SqrtOpLowering : public UnaryOpLLVMOpLowering<SqrtOp, LLVM::SqrtOp> {
+struct SqrtOpLowering
+ : public VectorConvertToLLVMPattern<SqrtOp, LLVM::SqrtOp> {
using Super::Super;
};
struct UnsignedDivIOpLowering
- : public BinaryOpLLVMOpLowering<UnsignedDivIOp, LLVM::UDivOp> {
+ : public VectorConvertToLLVMPattern<UnsignedDivIOp, LLVM::UDivOp> {
using Super::Super;
};
struct SignedRemIOpLowering
- : public BinaryOpLLVMOpLowering<SignedRemIOp, LLVM::SRemOp> {
+ : public VectorConvertToLLVMPattern<SignedRemIOp, LLVM::SRemOp> {
using Super::Super;
};
struct UnsignedRemIOpLowering
- : public BinaryOpLLVMOpLowering<UnsignedRemIOp, LLVM::URemOp> {
+ : public VectorConvertToLLVMPattern<UnsignedRemIOp, LLVM::URemOp> {
using Super::Super;
};
-struct AndOpLowering : public BinaryOpLLVMOpLowering<AndOp, LLVM::AndOp> {
+struct AndOpLowering : public VectorConvertToLLVMPattern<AndOp, LLVM::AndOp> {
using Super::Super;
};
-struct OrOpLowering : public BinaryOpLLVMOpLowering<OrOp, LLVM::OrOp> {
+struct OrOpLowering : public VectorConvertToLLVMPattern<OrOp, LLVM::OrOp> {
using Super::Super;
};
-struct XOrOpLowering : public BinaryOpLLVMOpLowering<XOrOp, LLVM::XOrOp> {
+struct XOrOpLowering : public VectorConvertToLLVMPattern<XOrOp, LLVM::XOrOp> {
using Super::Super;
};
-struct AddFOpLowering : public BinaryOpLLVMOpLowering<AddFOp, LLVM::FAddOp> {
+struct AddFOpLowering
+ : public VectorConvertToLLVMPattern<AddFOp, LLVM::FAddOp> {
using Super::Super;
};
-struct SubFOpLowering : public BinaryOpLLVMOpLowering<SubFOp, LLVM::FSubOp> {
+struct SubFOpLowering
+ : public VectorConvertToLLVMPattern<SubFOp, LLVM::FSubOp> {
using Super::Super;
};
-struct MulFOpLowering : public BinaryOpLLVMOpLowering<MulFOp, LLVM::FMulOp> {
+struct MulFOpLowering
+ : public VectorConvertToLLVMPattern<MulFOp, LLVM::FMulOp> {
using Super::Super;
};
-struct DivFOpLowering : public BinaryOpLLVMOpLowering<DivFOp, LLVM::FDivOp> {
+struct DivFOpLowering
+ : public VectorConvertToLLVMPattern<DivFOp, LLVM::FDivOp> {
using Super::Super;
};
-struct RemFOpLowering : public BinaryOpLLVMOpLowering<RemFOp, LLVM::FRemOp> {
+struct RemFOpLowering
+ : public VectorConvertToLLVMPattern<RemFOp, LLVM::FRemOp> {
using Super::Super;
};
struct CopySignOpLowering
- : public BinaryOpLLVMOpLowering<CopySignOp, LLVM::CopySignOp> {
+ : public VectorConvertToLLVMPattern<CopySignOp, LLVM::CopySignOp> {
using Super::Super;
};
struct SelectOpLowering
@@ -1695,24 +1682,21 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
if (!vectorType)
return failure();
- if (succeeded(HandleMultidimensionalVectors(
- op, operands, typeConverter,
- [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
- auto splatAttr = SplatElementsAttr::get(
- mlir::VectorType::get({llvmVectorTy.getUnderlyingType()
- ->getVectorNumElements()},
- floatType),
- floatOne);
- auto one = rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy,
- splatAttr);
- auto sqrt =
- rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, operands[0]);
- return rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one,
- sqrt);
- },
- rewriter)))
- return success();
- return failure();
+ return handleMultidimensionalVectors(
+ op, operands, typeConverter,
+ [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
+ auto splatAttr = SplatElementsAttr::get(
+ mlir::VectorType::get(
+ {llvmVectorTy.getUnderlyingType()->getVectorNumElements()},
+ floatType),
+ floatOne);
+ auto one =
+ rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
+ auto sqrt =
+ rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, operands[0]);
+ return rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one, sqrt);
+ },
+ rewriter);
}
};
More information about the Mlir-commits
mailing list