[Mlir-commits] [mlir] cdeb4a8 - [mlir] Allow lowering cmpi/cmpf with multidimensional vectors to LLVM

Benjamin Kramer llvmlistbot at llvm.org
Mon May 3 02:30:47 PDT 2021


Author: Benjamin Kramer
Date: 2021-05-03T11:30:21+02:00
New Revision: cdeb4a8a6430941563c9f04d178d4c5211164aa3

URL: https://github.com/llvm/llvm-project/commit/cdeb4a8a6430941563c9f04d178d4c5211164aa3
DIFF: https://github.com/llvm/llvm-project/commit/cdeb4a8a6430941563c9f04d178d4c5211164aa3.diff

LOG: [mlir] Allow lowering cmpi/cmpf with multidimensional vectors to LLVM

Differential Revision: https://reviews.llvm.org/D101535

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 32ff6b526f2e1..803ea52fa717d 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -3065,11 +3065,32 @@ struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
   matchAndRewrite(CmpIOp cmpiOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     CmpIOpAdaptor transformed(operands);
+    auto operandType = transformed.lhs().getType();
+    auto resultType = cmpiOp.getResult().getType();
 
-    rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
-        cmpiOp, typeConverter->convertType(cmpiOp.getResult().getType()),
-        convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()),
-        transformed.lhs(), transformed.rhs());
+    // Handle the scalar and 1D vector cases.
+    if (!operandType.isa<LLVM::LLVMArrayType>()) {
+      rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
+          cmpiOp, typeConverter->convertType(resultType),
+          convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()),
+          transformed.lhs(), transformed.rhs());
+      return success();
+    }
+
+    auto vectorType = resultType.dyn_cast<VectorType>();
+    if (!vectorType)
+      return rewriter.notifyMatchFailure(cmpiOp, "expected vector result type");
+
+    return handleMultidimensionalVectors(
+        cmpiOp.getOperation(), operands, *getTypeConverter(),
+        [&](Type llvm1DVectorTy, ValueRange operands) {
+          CmpIOpAdaptor transformed(operands);
+          return rewriter.create<LLVM::ICmpOp>(
+              cmpiOp.getLoc(), llvm1DVectorTy,
+              convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()),
+              transformed.lhs(), transformed.rhs());
+        },
+        rewriter);
 
     return success();
   }
@@ -3082,13 +3103,32 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
   matchAndRewrite(CmpFOp cmpfOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     CmpFOpAdaptor transformed(operands);
+    auto operandType = transformed.lhs().getType();
+    auto resultType = cmpfOp.getResult().getType();
 
-    rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
-        cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()),
-        convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()),
-        transformed.lhs(), transformed.rhs());
+    // Handle the scalar and 1D vector cases.
+    if (!operandType.isa<LLVM::LLVMArrayType>()) {
+      rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
+          cmpfOp, typeConverter->convertType(resultType),
+          convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()),
+          transformed.lhs(), transformed.rhs());
+      return success();
+    }
 
-    return success();
+    auto vectorType = resultType.dyn_cast<VectorType>();
+    if (!vectorType)
+      return rewriter.notifyMatchFailure(cmpfOp, "expected vector result type");
+
+    return handleMultidimensionalVectors(
+        cmpfOp.getOperation(), operands, *getTypeConverter(),
+        [&](Type llvm1DVectorTy, ValueRange operands) {
+          CmpFOpAdaptor transformed(operands);
+          return rewriter.create<LLVM::FCmpOp>(
+              cmpfOp.getLoc(), llvm1DVectorTy,
+              convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()),
+              transformed.lhs(), transformed.rhs());
+        },
+        rewriter);
   }
 };
 

diff  --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index b53de92935fe3..87f0b01c52100 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -274,3 +274,26 @@ func @index_vector(%arg0: vector<4xindex>) {
   std.return
 }
 
+// -----
+
+// CHECK-LABEL: func @cmpf_2dvector(
+func @cmpf_2dvector(%arg0 : vector<4x3xf32>, %arg1 : vector<4x3xf32>) {
+  // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %arg0[0] : !llvm.array<4 x vector<3xf32>>
+  // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %arg1[0] : !llvm.array<4 x vector<3xf32>>
+  // CHECK: %[[CMP:.*]] = llvm.fcmp "olt" %[[EXTRACT1]], %[[EXTRACT2]] : vector<3xf32>
+  // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[CMP]], %0[0] : !llvm.array<4 x vector<3xi1>>
+  %0 = cmpf olt, %arg0, %arg1 : vector<4x3xf32>
+  std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @cmpi_2dvector(
+func @cmpi_2dvector(%arg0 : vector<4x3xi32>, %arg1 : vector<4x3xi32>) {
+  // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %arg0[0] : !llvm.array<4 x vector<3xi32>>
+  // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %arg1[0] : !llvm.array<4 x vector<3xi32>>
+  // CHECK: %[[CMP:.*]] = llvm.icmp "ult" %[[EXTRACT1]], %[[EXTRACT2]] : vector<3xi32>
+  // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[CMP]], %0[0] : !llvm.array<4 x vector<3xi1>>
+  %0 = cmpi ult, %arg0, %arg1 : vector<4x3xi32>
+  std.return
+}


        


More information about the Mlir-commits mailing list