[Mlir-commits] [mlir] cdeb4a8 - [mlir] Allow lowering cmpi/cmpf with multidimensional vectors to LLVM
Benjamin Kramer
llvmlistbot at llvm.org
Mon May 3 02:30:47 PDT 2021
Author: Benjamin Kramer
Date: 2021-05-03T11:30:21+02:00
New Revision: cdeb4a8a6430941563c9f04d178d4c5211164aa3
URL: https://github.com/llvm/llvm-project/commit/cdeb4a8a6430941563c9f04d178d4c5211164aa3
DIFF: https://github.com/llvm/llvm-project/commit/cdeb4a8a6430941563c9f04d178d4c5211164aa3.diff
LOG: [mlir] Allow lowering cmpi/cmpf with multidimensional vectors to LLVM
Differential Revision: https://reviews.llvm.org/D101535
Added:
Modified:
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 32ff6b526f2e1..803ea52fa717d 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -3065,13 +3065,33 @@ struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
matchAndRewrite(CmpIOp cmpiOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
CmpIOpAdaptor transformed(operands);
+ auto operandType = transformed.lhs().getType();
+ auto resultType = cmpiOp.getResult().getType();
- rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
- cmpiOp, typeConverter->convertType(cmpiOp.getResult().getType()),
- convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()),
- transformed.lhs(), transformed.rhs());
+ // Handle the scalar and 1D vector cases.
+ if (!operandType.isa<LLVM::LLVMArrayType>()) {
+ rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
+ cmpiOp, typeConverter->convertType(resultType),
+ convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()),
+ transformed.lhs(), transformed.rhs());
+ return success();
+ }
+
+ // n-D vectors are converted to !llvm.array of 1D vectors; unroll the
+ // comparison into one 1D icmp per array element.
+ if (!resultType.isa<VectorType>())
+ return rewriter.notifyMatchFailure(cmpiOp, "expected vector result type");
+
+ return handleMultidimensionalVectors(
+ cmpiOp.getOperation(), operands, *getTypeConverter(),
+ [&](Type llvm1DVectorTy, ValueRange operands) {
+ CmpIOpAdaptor transformed(operands);
+ return rewriter.create<LLVM::ICmpOp>(
+ cmpiOp.getLoc(), llvm1DVectorTy,
+ convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()),
+ transformed.lhs(), transformed.rhs());
+ },
+ rewriter);
-
- return success();
}
};
@@ -3082,13 +3102,33 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
matchAndRewrite(CmpFOp cmpfOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
CmpFOpAdaptor transformed(operands);
+ auto operandType = transformed.lhs().getType();
+ auto resultType = cmpfOp.getResult().getType();
- rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
- cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()),
- convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()),
- transformed.lhs(), transformed.rhs());
+ // Handle the scalar and 1D vector cases.
+ if (!operandType.isa<LLVM::LLVMArrayType>()) {
+ rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
+ cmpfOp, typeConverter->convertType(resultType),
+ convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()),
+ transformed.lhs(), transformed.rhs());
+ return success();
+ }
- return success();
+ // n-D vectors are converted to !llvm.array of 1D vectors; unroll the
+ // comparison into one 1D fcmp per array element.
+ if (!resultType.isa<VectorType>())
+ return rewriter.notifyMatchFailure(cmpfOp, "expected vector result type");
+
+ return handleMultidimensionalVectors(
+ cmpfOp.getOperation(), operands, *getTypeConverter(),
+ [&](Type llvm1DVectorTy, ValueRange operands) {
+ CmpFOpAdaptor transformed(operands);
+ return rewriter.create<LLVM::FCmpOp>(
+ cmpfOp.getLoc(), llvm1DVectorTy,
+ convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()),
+ transformed.lhs(), transformed.rhs());
+ },
+ rewriter);
}
};
diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index b53de92935fe3..87f0b01c52100 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -274,3 +274,30 @@ func @index_vector(%arg0: vector<4xindex>) {
std.return
}
+// -----
+
+// CHECK-LABEL: func @cmpf_2dvector(
+func @cmpf_2dvector(%arg0 : vector<4x3xf32>, %arg1 : vector<4x3xf32>) {
+ // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef : !llvm.array<4 x vector<3xi1>>
+ // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %arg0[0] : !llvm.array<4 x vector<3xf32>>
+ // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %arg1[0] : !llvm.array<4 x vector<3xf32>>
+ // CHECK: %[[CMP:.*]] = llvm.fcmp "olt" %[[EXTRACT1]], %[[EXTRACT2]] : vector<3xf32>
+ // CHECK: llvm.insertvalue %[[CMP]], %[[UNDEF]][0] : !llvm.array<4 x vector<3xi1>>
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3] : !llvm.array<4 x vector<3xi1>>
+ %0 = cmpf olt, %arg0, %arg1 : vector<4x3xf32>
+ std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @cmpi_2dvector(
+func @cmpi_2dvector(%arg0 : vector<4x3xi32>, %arg1 : vector<4x3xi32>) {
+ // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef : !llvm.array<4 x vector<3xi1>>
+ // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %arg0[0] : !llvm.array<4 x vector<3xi32>>
+ // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %arg1[0] : !llvm.array<4 x vector<3xi32>>
+ // CHECK: %[[CMP:.*]] = llvm.icmp "ult" %[[EXTRACT1]], %[[EXTRACT2]] : vector<3xi32>
+ // CHECK: llvm.insertvalue %[[CMP]], %[[UNDEF]][0] : !llvm.array<4 x vector<3xi1>>
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3] : !llvm.array<4 x vector<3xi1>>
+ %0 = cmpi ult, %arg0, %arg1 : vector<4x3xi32>
+ std.return
+}
More information about the Mlir-commits mailing list