[Mlir-commits] [mlir] 86306df - Extract common code to deal with multidimensional vectors.
Adrian Kuegel
llvmlistbot at llvm.org
Fri Mar 6 04:55:30 PST 2020
Author: Adrian Kuegel
Date: 2020-03-06T13:54:54+01:00
New Revision: 86306df7dd2a8e60d88c6306956080b53ac95589
URL: https://github.com/llvm/llvm-project/commit/86306df7dd2a8e60d88c6306956080b53ac95589
DIFF: https://github.com/llvm/llvm-project/commit/86306df7dd2a8e60d88c6306956080b53ac95589.diff
LOG: Extract common code to deal with multidimensional vectors.
Summary: Also replace dyn_cast_or_null with dyn_cast when possible.
Differential Revision: https://reviews.llvm.org/D75733
Added:
Modified:
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index e38322943cfe..c7c479bf6779 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -24,6 +24,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/Functional.h"
+#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
@@ -32,6 +33,7 @@
#include "llvm/IR/Type.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
+#include <functional>
using namespace mlir;
@@ -1165,6 +1167,36 @@ void ValidateOpCount() {
OpCountValidator<SourceOp, OpCount>();
}
+/// Lowers `op`, whose first result has a VectorType, by unrolling it into
+/// operations on 1-D LLVM vectors.
+///
+/// For every `position` produced by `nDVectorIterate`, the 1-D vector at that
+/// position is extracted from each of `operands` (LLVM::ExtractValueOp),
+/// `createOperand` is called with the 1-D LLVM vector type and the extracted
+/// values, and the value it returns is inserted (LLVM::InsertValueOp) at the
+/// same position of an aggregate that starts out as an LLVM::UndefOp. `op` is
+/// then replaced by that aggregate and success() is returned.
+///
+/// Returns failure() without creating any IR when the first result of `op` is
+/// not a VectorType, when extractNDVectorTypeInfo yields no 1-D LLVM vector
+/// type, or when the LLVM type of operands[0] differs from the aggregate type
+/// recorded in `vectorTypeInfo.llvmArrayTy`.
+///
+/// NOTE(review): `operands` must be non-empty (operands[0] is read without a
+/// check), and only operands[0]'s type is validated; the remaining operands
+/// are presumably expected to have the same type -- confirm at call sites.
+static LogicalResult HandleMultidimensionalVectors(
+ Operation *op, ArrayRef<Value> operands, LLVMTypeConverter &typeConverter,
+ std::function<Value(LLVM::LLVMType, ValueRange)> createOperand,
+ ConversionPatternRewriter &rewriter) {
+ // Only ops whose first result is a vector are handled here.
+ auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
+ if (!vectorType)
+ return failure();
+ auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, typeConverter);
+ auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
+ auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
+ // Bail out (before touching the IR) if there is no 1-D vector type to unroll
+ // into, or if the converted operand type is not the aggregate type that the
+ // vector result lowers to.
+ if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
+ return failure();
+
+ auto loc = op->getLoc();
+ // The result aggregate is filled in one position at a time, starting from
+ // an undefined value of the operand's aggregate type.
+ Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
+ nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
+ // For this unrolled `position`, extract the 1-D vector slice of every
+ // operand.
+ SmallVector<Value, 4> extractedOperands;
+ for (auto operand : operands)
+ extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
+ loc, llvmVectorTy, operand, position));
+ // The caller decides which 1-D operation to build from the slices.
+ Value newVal = createOperand(llvmVectorTy, extractedOperands);
+ desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, newVal,
+ position);
+ });
+ rewriter.replaceOp(op, desc);
+ return success();
+}
+
// Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect
// Ops for N-ary ops with one result. This supports higher-dimensional vector
// types.
@@ -1192,7 +1224,6 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
return this->matchFailure();
}
- auto loc = op->getLoc();
auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
if (!llvmArrayTy.isArrayTy()) {
@@ -1202,31 +1233,15 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
return this->matchSuccess();
}
- auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
- if (!vectorType)
- return this->matchFailure();
- auto vectorTypeInfo =
- extractNDVectorTypeInfo(vectorType, this->typeConverter);
- auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
- if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
- return this->matchFailure();
-
- Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
- nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
- // For this unrolled `position` corresponding to the `linearIndex`^th
- // element, extract operand vectors
- SmallVector<Value, OpCount> extractedOperands;
- for (unsigned i = 0; i < OpCount; ++i) {
- extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
- loc, llvmVectorTy, operands[i], position));
- }
- Value newVal = rewriter.create<TargetOp>(
- loc, llvmVectorTy, extractedOperands, op->getAttrs());
- desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc,
- newVal, position);
- });
- rewriter.replaceOp(op, desc);
- return this->matchSuccess();
+ if (succeeded(HandleMultidimensionalVectors(
+ op, operands, this->typeConverter,
+ [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
+ return rewriter.create<TargetOp>(op->getLoc(), llvmVectorTy,
+ operands, op->getAttrs());
+ },
+ rewriter)))
+ return this->matchSuccess();
+ return this->matchFailure();
}
};
@@ -1673,7 +1688,7 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
ConversionPatternRewriter &rewriter) const override {
OperandAdaptor<RsqrtOp> transformed(operands);
auto operandType =
- transformed.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();
+ transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
if (!operandType)
return matchFailure();
@@ -1694,41 +1709,31 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
}
auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
- return matchSuccess();
+ return this->matchSuccess();
}
auto vectorType = resultType.dyn_cast<VectorType>();
if (!vectorType)
return this->matchFailure();
- auto vectorTypeInfo =
- extractNDVectorTypeInfo(vectorType, this->typeConverter);
- auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
- if (!llvmVectorTy || operandType != vectorTypeInfo.llvmArrayTy)
- return this->matchFailure();
-
- Value desc = rewriter.create<LLVM::UndefOp>(loc, operandType);
- nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
- // For this unrolled `position` corresponding to the `linearIndex`^th
- // element, extract operand vectors
- auto extractedOperand = rewriter.create<LLVM::ExtractValueOp>(
- loc, llvmVectorTy, operands[0], position);
- auto splatAttr = SplatElementsAttr::get(
- mlir::VectorType::get(
- {llvmVectorTy.getUnderlyingType()->getVectorNumElements()},
- floatType),
- floatOne);
- auto one =
- rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
- auto sqrt =
- rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, extractedOperand);
- auto div = rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one, sqrt);
- desc = rewriter.create<LLVM::InsertValueOp>(loc, operandType, desc, div,
- position);
- });
- rewriter.replaceOp(op, desc);
-
- return matchSuccess();
+ if (succeeded(HandleMultidimensionalVectors(
+ op, operands, typeConverter,
+ [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
+ auto splatAttr = SplatElementsAttr::get(
+ mlir::VectorType::get({llvmVectorTy.getUnderlyingType()
+ ->getVectorNumElements()},
+ floatType),
+ floatOne);
+ auto one = rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy,
+ splatAttr);
+ auto sqrt =
+ rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, operands[0]);
+ return rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one,
+ sqrt);
+ },
+ rewriter)))
+ return this->matchSuccess();
+ return this->matchFailure();
}
};
@@ -1745,7 +1750,7 @@ struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
OperandAdaptor<TanhOp> transformed(operands);
LLVMTypeT operandType =
- transformed.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();
+ transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
if (!operandType)
return matchFailure();
More information about the Mlir-commits mailing list