[Mlir-commits] [mlir] 96a7900 - [mlir] Fix multidimensional lowering from std.select to llvm.select

Benjamin Kramer llvmlistbot at llvm.org
Mon May 3 10:31:03 PDT 2021


Author: Benjamin Kramer
Date: 2021-05-03T19:30:49+02:00
New Revision: 96a7900eb065385e1f61d926affaed13af446962

URL: https://github.com/llvm/llvm-project/commit/96a7900eb065385e1f61d926affaed13af446962
DIFF: https://github.com/llvm/llvm-project/commit/96a7900eb065385e1f61d926affaed13af446962.diff

LOG: [mlir] Fix multidimensional lowering from std.select to llvm.select

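The converter assumed that all operands have the same type, but that is
not true for select: its condition operand is a vector of i1, while the
other operands have the result's element type.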

Differential Revision: https://reviews.llvm.org/D101767

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 803ea52fa717..8df65042e154 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1643,16 +1643,18 @@ static LogicalResult handleMultidimensionalVectors(
     Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
     std::function<Value(Type, ValueRange)> createOperand,
     ConversionPatternRewriter &rewriter) {
-  auto operandNDVectorType = op->getOperand(0).getType().dyn_cast<VectorType>();
-  auto resultNDVectorType = op->getResult(0).getType().dyn_cast<VectorType>();
-  assert(operandNDVectorType && resultNDVectorType && "expected vector types");
-
+  auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();
+
+  SmallVector<Type> operand1DVectorTypes;
+  for (Value operand : op->getOperands()) {
+    auto operandNDVectorType = operand.getType().cast<VectorType>();
+    auto operandTypeInfo =
+        extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
+    operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy);
+  }
   auto resultTypeInfo =
       extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
-  auto operandTypeInfo =
-      extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
   auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
-  auto operand1DVectorTy = operandTypeInfo.llvm1DVectorTy;
   auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
   auto loc = op->getLoc();
   Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy);
@@ -1660,9 +1662,11 @@ static LogicalResult handleMultidimensionalVectors(
     // For this unrolled `position` corresponding to the `linearIndex`^th
     // element, extract operand vectors
     SmallVector<Value, 4> extractedOperands;
-    for (auto operand : operands)
+    for (auto operand : llvm::enumerate(operands)) {
       extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
-          loc, operand1DVectorTy, operand, position));
+          loc, operand1DVectorTypes[operand.index()], operand.value(),
+          position));
+    }
     Value newVal = createOperand(result1DVectorTy, extractedOperands);
     desc = rewriter.create<LLVM::InsertValueOp>(loc, resultNDVectoryTy, desc,
                                                 newVal, position);
@@ -1723,7 +1727,7 @@ using NegFOpLowering = VectorConvertToLLVMPattern<NegFOp, LLVM::FNegOp>;
 using OrOpLowering = VectorConvertToLLVMPattern<OrOp, LLVM::OrOp>;
 using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
 using RemFOpLowering = VectorConvertToLLVMPattern<RemFOp, LLVM::FRemOp>;
-using SelectOpLowering = OneToOneConvertToLLVMPattern<SelectOp, LLVM::SelectOp>;
+using SelectOpLowering = VectorConvertToLLVMPattern<SelectOp, LLVM::SelectOp>;
 using SignExtendIOpLowering =
     VectorConvertToLLVMPattern<SignExtendIOp, LLVM::SExtOp>;
 using ShiftLeftOpLowering =

diff  --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index 87f0b01c5210..4a2983b44eb9 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -297,3 +297,16 @@ func @cmpi_2dvector(%arg0 : vector<4x3xi32>, %arg1 : vector<4x3xi32>) {
   %0 = cmpi ult, %arg0, %arg1 : vector<4x3xi32>
   std.return
 }
+
+// -----
+
+// CHECK-LABEL: func @select_2dvector(
+func @select_2dvector(%arg0 : vector<4x3xi1>, %arg1 : vector<4x3xi32>, %arg2 : vector<4x3xi32>) {
+  // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %arg0[0] : !llvm.array<4 x vector<3xi1>>
+  // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %arg1[0] : !llvm.array<4 x vector<3xi32>>
+  // CHECK: %[[EXTRACT3:.*]] = llvm.extractvalue %arg2[0] : !llvm.array<4 x vector<3xi32>>
+  // CHECK: %[[SELECT:.*]] = llvm.select %[[EXTRACT1]], %[[EXTRACT2]], %[[EXTRACT3]] : vector<3xi1>, vector<3xi32>
+  // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[SELECT]], %0[0] : !llvm.array<4 x vector<3xi32>>
+  %0 = select %arg0, %arg1, %arg2 : vector<4x3xi1>, vector<4x3xi32>
+  std.return
+}


        


More information about the Mlir-commits mailing list