[Mlir-commits] [mlir] 96a7900 - [mlir] Fix multidimensional lowering from std.select to llvm.select
Benjamin Kramer
llvmlistbot at llvm.org
Mon May 3 10:31:03 PDT 2021
Author: Benjamin Kramer
Date: 2021-05-03T19:30:49+02:00
New Revision: 96a7900eb065385e1f61d926affaed13af446962
URL: https://github.com/llvm/llvm-project/commit/96a7900eb065385e1f61d926affaed13af446962
DIFF: https://github.com/llvm/llvm-project/commit/96a7900eb065385e1f61d926affaed13af446962.diff
LOG: [mlir] Fix multidimensional lowering from std.select to llvm.select
The converter assumed that all operands have the same type; that's not
true for select.
Differential Revision: https://reviews.llvm.org/D101767
Added:
Modified:
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 803ea52fa717..8df65042e154 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1643,16 +1643,18 @@ static LogicalResult handleMultidimensionalVectors(
Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
std::function<Value(Type, ValueRange)> createOperand,
ConversionPatternRewriter &rewriter) {
- auto operandNDVectorType = op->getOperand(0).getType().dyn_cast<VectorType>();
- auto resultNDVectorType = op->getResult(0).getType().dyn_cast<VectorType>();
- assert(operandNDVectorType && resultNDVectorType && "expected vector types");
-
+ auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();
+
+ SmallVector<Type> operand1DVectorTypes;
+ for (Value operand : op->getOperands()) {
+ auto operandNDVectorType = operand.getType().cast<VectorType>();
+ auto operandTypeInfo =
+ extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
+ operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy);
+ }
auto resultTypeInfo =
extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
- auto operandTypeInfo =
- extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
- auto operand1DVectorTy = operandTypeInfo.llvm1DVectorTy;
auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
auto loc = op->getLoc();
Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy);
@@ -1660,9 +1662,11 @@ static LogicalResult handleMultidimensionalVectors(
// For this unrolled `position` corresponding to the `linearIndex`^th
// element, extract operand vectors
SmallVector<Value, 4> extractedOperands;
- for (auto operand : operands)
+ for (auto operand : llvm::enumerate(operands)) {
extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
- loc, operand1DVectorTy, operand, position));
+ loc, operand1DVectorTypes[operand.index()], operand.value(),
+ position));
+ }
Value newVal = createOperand(result1DVectorTy, extractedOperands);
desc = rewriter.create<LLVM::InsertValueOp>(loc, resultNDVectoryTy, desc,
newVal, position);
@@ -1723,7 +1727,7 @@ using NegFOpLowering = VectorConvertToLLVMPattern<NegFOp, LLVM::FNegOp>;
using OrOpLowering = VectorConvertToLLVMPattern<OrOp, LLVM::OrOp>;
using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
using RemFOpLowering = VectorConvertToLLVMPattern<RemFOp, LLVM::FRemOp>;
-using SelectOpLowering = OneToOneConvertToLLVMPattern<SelectOp, LLVM::SelectOp>;
+using SelectOpLowering = VectorConvertToLLVMPattern<SelectOp, LLVM::SelectOp>;
using SignExtendIOpLowering =
VectorConvertToLLVMPattern<SignExtendIOp, LLVM::SExtOp>;
using ShiftLeftOpLowering =
diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index 87f0b01c5210..4a2983b44eb9 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -297,3 +297,16 @@ func @cmpi_2dvector(%arg0 : vector<4x3xi32>, %arg1 : vector<4x3xi32>) {
%0 = cmpi ult, %arg0, %arg1 : vector<4x3xi32>
std.return
}
+
+// -----
+
+// CHECK-LABEL: func @select_2dvector(
+func @select_2dvector(%arg0 : vector<4x3xi1>, %arg1 : vector<4x3xi32>, %arg2 : vector<4x3xi32>) {
+ // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %arg0[0] : !llvm.array<4 x vector<3xi1>>
+ // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %arg1[0] : !llvm.array<4 x vector<3xi32>>
+ // CHECK: %[[EXTRACT3:.*]] = llvm.extractvalue %arg2[0] : !llvm.array<4 x vector<3xi32>>
+ // CHECK: %[[SELECT:.*]] = llvm.select %[[EXTRACT1]], %[[EXTRACT2]], %[[EXTRACT3]] : vector<3xi1>, vector<3xi32>
+ // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[SELECT]], %0[0] : !llvm.array<4 x vector<3xi32>>
+ %0 = select %arg0, %arg1, %arg2 : vector<4x3xi1>, vector<4x3xi32>
+ std.return
+}
More information about the Mlir-commits mailing list