[Mlir-commits] [mlir] 49c9c3a - [mlir][Standard] Extend n-D vector lowering to LLVM to [s|z]exti ops.
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Feb 1 23:45:58 PST 2021
Author: Nicolas Vasilache
Date: 2021-02-02T07:45:50Z
New Revision: 49c9c3a59e21205aabe48f92b9e53bf666348388
URL: https://github.com/llvm/llvm-project/commit/49c9c3a59e21205aabe48f92b9e53bf666348388
DIFF: https://github.com/llvm/llvm-project/commit/49c9c3a59e21205aabe48f92b9e53bf666348388.diff
LOG: [mlir][Standard] Extend n-D vector lowering to LLVM to [s|z]exti ops.
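[s|z]exti ops do not have the same operand and result types.
As a consequence, the n-D vector lowering must be relaxed: it can no longer assume that the operand and result vector types are identical, and must compute the LLVM types for each separately.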
This revision additionally performs a few NFC renamings of variables to make them more intuitive.
Differential Revision: https://reviews.llvm.org/D95760
Added:
mlir/test/Conversion/StandardToLLVM/convert-nd-vector-to-llvmir.mlir
Modified:
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 357bd2f021b1..90ebd94c4207 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -656,9 +656,6 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
static_assert(
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
"expected single result op");
- static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
- SourceOp>::value,
- "expected same operands and result type");
return LLVM::detail::vectorOneToOneRewrite(
op, TargetOp::getOperationName(), operands, *this->getTypeConverter(),
rewriter);
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index c59cbdab8fb7..bb6376cf20f3 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1472,10 +1472,10 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
// 1-D LLVM vectors.
struct NDVectorTypeInfo {
// LLVM array struct which encodes n-D vectors.
- Type llvmArrayTy;
+ Type llvmNDVectorTy;
// LLVM vector type which encodes the inner 1-D vector type.
- Type llvmVectorTy;
- // Multiplicity of llvmArrayTy to llvmVectorTy.
+ Type llvm1DVectorTy;
+ // Multiplicity of llvmNDVectorTy to llvm1DVectorTy.
SmallVector<int64_t, 4> arraySizes;
};
} // namespace
@@ -1488,13 +1488,13 @@ static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
LLVMTypeConverter &converter) {
assert(vectorType.getRank() > 1 && "expected >1D vector type");
NDVectorTypeInfo info;
- info.llvmArrayTy = converter.convertType(vectorType);
- if (!info.llvmArrayTy || !LLVM::isCompatibleType(info.llvmArrayTy)) {
- info.llvmArrayTy = nullptr;
+ info.llvmNDVectorTy = converter.convertType(vectorType);
+ if (!info.llvmNDVectorTy || !LLVM::isCompatibleType(info.llvmNDVectorTy)) {
+ info.llvmNDVectorTy = nullptr;
return info;
}
info.arraySizes.reserve(vectorType.getRank() - 1);
- auto llvmTy = info.llvmArrayTy;
+ auto llvmTy = info.llvmNDVectorTy;
while (llvmTy.isa<LLVM::LLVMArrayType>()) {
info.arraySizes.push_back(
llvmTy.cast<LLVM::LLVMArrayType>().getNumElements());
@@ -1502,7 +1502,7 @@ static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
}
if (!LLVM::isCompatibleVectorType(llvmTy))
return info;
- info.llvmVectorTy = llvmTy;
+ info.llvm1DVectorTy = llvmTy;
return info;
}
@@ -1591,27 +1591,29 @@ static LogicalResult handleMultidimensionalVectors(
Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
std::function<Value(Type, ValueRange)> createOperand,
ConversionPatternRewriter &rewriter) {
- auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
- if (!vectorType)
- return failure();
- auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, typeConverter);
- auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
- auto llvmArrayTy = operands[0].getType();
- if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
- return failure();
-
+ auto operandNDVectorType = op->getOperand(0).getType().dyn_cast<VectorType>();
+ auto resultNDVectorType = op->getResult(0).getType().dyn_cast<VectorType>();
+ assert(operandNDVectorType && resultNDVectorType && "expected vector types");
+
+ auto resultTypeInfo =
+ extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
+ auto operandTypeInfo =
+ extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
+ auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
+ auto operand1DVectorTy = operandTypeInfo.llvm1DVectorTy;
+ auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
auto loc = op->getLoc();
- Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
- nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
+ Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy);
+ nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayAttr position) {
// For this unrolled `position` corresponding to the `linearIndex`^th
// element, extract operand vectors
SmallVector<Value, 4> extractedOperands;
for (auto operand : operands)
extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
- loc, llvmVectorTy, operand, position));
- Value newVal = createOperand(llvmVectorTy, extractedOperands);
- desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, newVal,
- position);
+ loc, operand1DVectorTy, operand, position));
+ Value newVal = createOperand(result1DVectorTy, extractedOperands);
+ desc = rewriter.create<LLVM::InsertValueOp>(loc, resultNDVectoryTy, desc,
+ newVal, position);
});
rewriter.replaceOp(op, desc);
return success();
@@ -1627,14 +1629,14 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
[](Type t) { return isCompatibleType(t); }))
return failure();
- auto llvmArrayTy = operands[0].getType();
- if (!llvmArrayTy.isa<LLVM::LLVMArrayType>())
+ auto llvmNDVectorTy = operands[0].getType();
+ if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>())
return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
- auto callback = [op, targetOp, &rewriter](Type llvmVectorTy,
+ auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy,
ValueRange operands) {
OperationState state(op->getLoc(), targetOp);
- state.addTypes(llvmVectorTy);
+ state.addTypes(llvm1DVectorTy);
state.addOperands(operands);
state.addAttributes(op->getAttrs());
return rewriter.createOperation(state)->getResult(0);
@@ -1668,6 +1670,8 @@ using OrOpLowering = VectorConvertToLLVMPattern<OrOp, LLVM::OrOp>;
using PowFOpLowering = VectorConvertToLLVMPattern<PowFOp, LLVM::PowOp>;
using RemFOpLowering = VectorConvertToLLVMPattern<RemFOp, LLVM::FRemOp>;
using SelectOpLowering = OneToOneConvertToLLVMPattern<SelectOp, LLVM::SelectOp>;
+using SignExtendIOpLowering =
+ VectorConvertToLLVMPattern<SignExtendIOp, LLVM::SExtOp>;
using ShiftLeftOpLowering =
OneToOneConvertToLLVMPattern<ShiftLeftOp, LLVM::ShlOp>;
using SignedDivIOpLowering =
@@ -1687,6 +1691,8 @@ using UnsignedRemIOpLowering =
using UnsignedShiftRightOpLowering =
OneToOneConvertToLLVMPattern<UnsignedShiftRightOp, LLVM::LShrOp>;
using XOrOpLowering = VectorConvertToLLVMPattern<XOrOp, LLVM::XOrOp>;
+using ZeroExtendIOpLowering =
+ VectorConvertToLLVMPattern<ZeroExtendIOp, LLVM::ZExtOp>;
/// Lower `std.assert`. The default lowering calls the `abort` function if the
/// assertion is violated and has no effect otherwise. The failure message is
@@ -2366,17 +2372,17 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
return handleMultidimensionalVectors(
op.getOperation(), operands, *getTypeConverter(),
- [&](Type llvmVectorTy, ValueRange operands) {
+ [&](Type llvm1DVectorTy, ValueRange operands) {
auto splatAttr = SplatElementsAttr::get(
mlir::VectorType::get(
- {LLVM::getVectorNumElements(llvmVectorTy).getFixedValue()},
+ {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
floatType),
floatOne);
auto one =
- rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
+ rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
auto sqrt =
- rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, operands[0]);
- return rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one, sqrt);
+ rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
+ return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
},
rewriter);
}
@@ -3050,21 +3056,11 @@ struct FPTruncLowering
using Super::Super;
};
-struct SignExtendIOpLowering
- : public OneToOneConvertToLLVMPattern<SignExtendIOp, LLVM::SExtOp> {
- using Super::Super;
-};
-
struct TruncateIOpLowering
: public OneToOneConvertToLLVMPattern<TruncateIOp, LLVM::TruncOp> {
using Super::Super;
};
-struct ZeroExtendIOpLowering
- : public OneToOneConvertToLLVMPattern<ZeroExtendIOp, LLVM::ZExtOp> {
- using Super::Super;
-};
-
// Base class for LLVM IR lowering terminator operations with successors.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMTerminatorLowering
@@ -3211,21 +3207,21 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
auto loc = splatOp.getLoc();
auto vectorTypeInfo =
extractNDVectorTypeInfo(resultType, *getTypeConverter());
- auto llvmArrayTy = vectorTypeInfo.llvmArrayTy;
- auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
- if (!llvmArrayTy || !llvmVectorTy)
+ auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
+ auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
+ if (!llvmNDVectorTy || !llvm1DVectorTy)
return failure();
// Construct returned value.
- Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
+ Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
// Construct a 1-D vector with the splatted value that we insert in all the
// places within the returned descriptor.
- Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvmVectorTy);
+ Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
auto zero = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter->convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
- Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvmVectorTy, vdesc,
+ Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
adaptor.input(), zero);
// Shuffle the value across the desired number of elements.
@@ -3237,7 +3233,7 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
// Iterate of linear index, convert to coords space and insert splatted 1-D
// vector in each position.
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
- desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, v,
+ desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v,
position);
});
rewriter.replaceOp(splatOp, desc);
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-nd-vector-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-nd-vector-to-llvmir.mlir
new file mode 100644
index 000000000000..dce630dfcf23
--- /dev/null
+++ b/mlir/test/Conversion/StandardToLLVM/convert-nd-vector-to-llvmir.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt -convert-std-to-llvm %s -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @vec_bin
+func @vec_bin(%arg0: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
+ %0 = addf %arg0, %arg0 : vector<2x2x2xf32>
+ return %0 : vector<2x2x2xf32>
+
+// CHECK-NEXT: llvm.mlir.undef : !llvm.array<2 x array<2 x vector<2xf32>>>
+
+// This block appears 2x2 times
+// CHECK-NEXT: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<2 x array<2 x vector<2xf32>>>
+// CHECK-NEXT: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<2 x array<2 x vector<2xf32>>>
+// CHECK-NEXT: llvm.fadd %{{.*}} : vector<2xf32>
+// CHECK-NEXT: llvm.insertvalue %{{.*}}[0, 0] : !llvm.array<2 x array<2 x vector<2xf32>>>
+
+// We check the proper indexing of extract/insert in the remaining 3 positions.
+// CHECK: llvm.extractvalue %{{.*}}[0, 1] : !llvm.array<2 x array<2 x vector<2xf32>>>
+// CHECK: llvm.insertvalue %{{.*}}[0, 1] : !llvm.array<2 x array<2 x vector<2xf32>>>
+// CHECK: llvm.extractvalue %{{.*}}[1, 0] : !llvm.array<2 x array<2 x vector<2xf32>>>
+// CHECK: llvm.insertvalue %{{.*}}[1, 0] : !llvm.array<2 x array<2 x vector<2xf32>>>
+// CHECK: llvm.extractvalue %{{.*}}[1, 1] : !llvm.array<2 x array<2 x vector<2xf32>>>
+// CHECK: llvm.insertvalue %{{.*}}[1, 1] : !llvm.array<2 x array<2 x vector<2xf32>>>
+}
+
+// CHECK-LABEL: @sexti
+func @sexti_vector(%arg0 : vector<1x2x3xi32>, %arg1 : vector<1x2x3xi64>) {
+ // CHECK: llvm.mlir.undef : !llvm.array<1 x array<2 x vector<3xi64>>>
+ // CHECK: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi32>>>
+ // CHECK: llvm.sext %{{.*}} : vector<3xi32> to vector<3xi64>
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi64>>>
+ // CHECK: llvm.extractvalue %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xi32>>>
+ // CHECK: llvm.sext %{{.*}} : vector<3xi32> to vector<3xi64>
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xi64>>>
+ %0 = sexti %arg0: vector<1x2x3xi32> to vector<1x2x3xi64>
+ return
+}
+
+// CHECK-LABEL: @zexti
+func @zexti_vector(%arg0 : vector<1x2x3xi32>, %arg1 : vector<1x2x3xi64>) {
+ // CHECK: llvm.mlir.undef : !llvm.array<1 x array<2 x vector<3xi64>>>
+ // CHECK: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi32>>>
+ // CHECK: llvm.zext %{{.*}} : vector<3xi32> to vector<3xi64>
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi64>>>
+ // CHECK: llvm.extractvalue %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xi32>>>
+ // CHECK: llvm.zext %{{.*}} : vector<3xi32> to vector<3xi64>
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xi64>>>
+ %0 = zexti %arg0: vector<1x2x3xi32> to vector<1x2x3xi64>
+ return
+}
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 749d733b55f2..5081e7b910da 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -766,31 +766,6 @@ func @fcmp(f32, f32) -> () {
return
}
-// CHECK-LABEL: @vec_bin
-func @vec_bin(%arg0: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
- %0 = addf %arg0, %arg0 : vector<2x2x2xf32>
- return %0 : vector<2x2x2xf32>
-
-// CHECK-NEXT: llvm.mlir.undef : !llvm.array<2 x array<2 x vector<2xf32>>>
-
-// This block appears 2x2 times
-// CHECK-NEXT: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<2 x array<2 x vector<2xf32>>>
-// CHECK-NEXT: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<2 x array<2 x vector<2xf32>>>
-// CHECK-NEXT: llvm.fadd %{{.*}} : vector<2xf32>
-// CHECK-NEXT: llvm.insertvalue %{{.*}}[0, 0] : !llvm.array<2 x array<2 x vector<2xf32>>>
-
-// We check the proper indexing of extract/insert in the remaining 3 positions.
-// CHECK: llvm.extractvalue %{{.*}}[0, 1] : !llvm.array<2 x array<2 x vector<2xf32>>>
-// CHECK: llvm.insertvalue %{{.*}}[0, 1] : !llvm.array<2 x array<2 x vector<2xf32>>>
-// CHECK: llvm.extractvalue %{{.*}}[1, 0] : !llvm.array<2 x array<2 x vector<2xf32>>>
-// CHECK: llvm.insertvalue %{{.*}}[1, 0] : !llvm.array<2 x array<2 x vector<2xf32>>>
-// CHECK: llvm.extractvalue %{{.*}}[1, 1] : !llvm.array<2 x array<2 x vector<2xf32>>>
-// CHECK: llvm.insertvalue %{{.*}}[1, 1] : !llvm.array<2 x array<2 x vector<2xf32>>>
-
-// And we're done
-// CHECK-NEXT: return
-}
-
// CHECK-LABEL: @splat
// CHECK-SAME: %[[A:arg[0-9]+]]: vector<4xf32>
// CHECK-SAME: %[[ELT:arg[0-9]+]]: f32
More information about the Mlir-commits mailing list