[Mlir-commits] [mlir] f7bbe09 - [mlir] Add arith.addui_carry conversion to LLVM

Jakub Kuderski llvmlistbot at llvm.org
Thu Aug 25 08:10:11 PDT 2022


Author: Jakub Kuderski
Date: 2022-08-25T11:09:22-04:00
New Revision: f7bbe099c996aa38ceee701eea9bd81cfe6066e3

URL: https://github.com/llvm/llvm-project/commit/f7bbe099c996aa38ceee701eea9bd81cfe6066e3
DIFF: https://github.com/llvm/llvm-project/commit/f7bbe099c996aa38ceee701eea9bd81cfe6066e3.diff

LOG: [mlir] Add arith.addui_carry conversion to LLVM

This covers the scalar and 1-D vector cases.

I haven't implemented conversion for the multidimensional vector case yet because
the current LLVM conversion infrastructure (`handleMultidimensionalVectors`) does
not seem to support ops with multiple results.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D132613

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp
    mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp b/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp
index a6e1b43fa34eb..df6df155f41d8 100644
--- a/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp
@@ -106,6 +106,15 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern<arith::IndexCastOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Lowers `arith.addui_carry` to the `llvm.intr.uadd.with.overflow` intrinsic.
+struct AddUICarryOpLowering : ConvertOpToLLVMPattern<arith::AddUICarryOp> {
+  using ConvertOpToLLVMPattern<arith::AddUICarryOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::AddUICarryOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
@@ -186,6 +195,45 @@ LogicalResult IndexCastOpLowering::matchAndRewrite(
       rewriter);
 }
 
+//===----------------------------------------------------------------------===//
+// AddUICarryOpLowering
+//===----------------------------------------------------------------------===//
+
+/// Lowers `arith.addui_carry` to `llvm.intr.uadd.with.overflow`, which returns
+/// a `{sum, carry}` struct whose fields become the two op results.
+LogicalResult AddUICarryOpLowering::matchAndRewrite(
+    arith::AddUICarryOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  Type operandType = adaptor.getLhs().getType();
+  if (!LLVM::isCompatibleType(operandType))
+    return rewriter.notifyMatchFailure(loc, "unsupported operand type");
+
+  // nD vectors are converted to LLVM arrays of 1D vectors; these would need
+  // to be unrolled, which is not implemented yet.
+  if (operandType.isa<LLVM::LLVMArrayType>()) {
+    if (!op.getSum().getType().isa<VectorType>())
+      return rewriter.notifyMatchFailure(loc, "expected vector result types");
+    return rewriter.notifyMatchFailure(loc,
+                                       "ND vector types are not supported yet");
+  }
+
+  // Scalar and 1D vector case. The sum has the operands' type, so use the
+  // already-converted operand type: the unconverted result type (e.g. `index`,
+  // if the op accepts it) is not necessarily a valid LLVM struct element.
+  Type carryType = typeConverter->convertType(op.getCarry().getType());
+  if (!carryType)
+    return rewriter.notifyMatchFailure(loc, "failed to convert carry type");
+  Type structType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
+                                                     {operandType, carryType});
+  Value addOverflow = rewriter.create<LLVM::UAddWithOverflowOp>(
+      loc, structType, adaptor.getLhs(), adaptor.getRhs());
+  Value sum = rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 0);
+  Value carry = rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 1);
+  rewriter.replaceOp(op, {sum, carry});
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // CmpIOpLowering
 //===----------------------------------------------------------------------===//
@@ -300,6 +348,7 @@ void mlir::arith::populateArithmeticToLLVMConversionPatterns(
     AddFOpLowering,
     AddIOpLowering,
     AndIOpLowering,
+    AddUICarryOpLowering,
     BitcastOpLowering,
     ConstantOpLowering,
     CmpFOpLowering,

diff  --git a/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
index 24664cfddb7a6..c476d43627275 100644
--- a/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
@@ -338,6 +338,30 @@ func.func @bitcast_1d(%arg0: vector<2xf32>) {
 
 // -----
 
+// CHECK-LABEL: @addui_carry_scalar
+// CHECK-SAME:    ([[LHS:%.+]]: i32, [[RHS:%.+]]: i32) -> (i32, i1)
+func.func @addui_carry_scalar(%lhs: i32, %rhs: i32) -> (i32, i1) {
+  // CHECK-NEXT: [[PAIR:%.+]] = "llvm.intr.uadd.with.overflow"([[LHS]], [[RHS]]) : (i32, i32) -> !llvm.struct<(i32, i1)>
+  // CHECK-NEXT: [[SUM:%.+]] = llvm.extractvalue [[PAIR]][0] : !llvm.struct<(i32, i1)>
+  // CHECK-NEXT: [[CARRY:%.+]] = llvm.extractvalue [[PAIR]][1] : !llvm.struct<(i32, i1)>
+  %sum, %carry = arith.addui_carry %lhs, %rhs : i32, i1
+  // CHECK-NEXT: return [[SUM]], [[CARRY]] : i32, i1
+  return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @addui_carry_vector1d
+// CHECK-SAME:    ([[LHS:%.+]]: vector<3xi16>, [[RHS:%.+]]: vector<3xi16>) -> (vector<3xi16>, vector<3xi1>)
+func.func @addui_carry_vector1d(%lhs: vector<3xi16>, %rhs: vector<3xi16>) -> (vector<3xi16>, vector<3xi1>) {
+  // CHECK-NEXT: [[PAIR:%.+]] = "llvm.intr.uadd.with.overflow"([[LHS]], [[RHS]]) : (vector<3xi16>, vector<3xi16>) -> !llvm.struct<(vector<3xi16>, vector<3xi1>)>
+  // CHECK-NEXT: [[SUM:%.+]] = llvm.extractvalue [[PAIR]][0] : !llvm.struct<(vector<3xi16>, vector<3xi1>)>
+  // CHECK-NEXT: [[CARRY:%.+]] = llvm.extractvalue [[PAIR]][1] : !llvm.struct<(vector<3xi16>, vector<3xi1>)>
+  %sum, %carry = arith.addui_carry %lhs, %rhs : vector<3xi16>, vector<3xi1>
+  // CHECK-NEXT: return [[SUM]], [[CARRY]] : vector<3xi16>, vector<3xi1>
+  return %sum, %carry : vector<3xi16>, vector<3xi1>
+}
+
+// -----
+
 // CHECK-LABEL: func @cmpf_2dvector(
 func.func @cmpf_2dvector(%arg0 : vector<4x3xf32>, %arg1 : vector<4x3xf32>) {
   // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast


        


More information about the Mlir-commits mailing list