[Mlir-commits] [mlir] 95c4e51 - [mlir][spirv] Add arith.addi_carry to spv.IAddCarry conversion

Jakub Kuderski llvmlistbot at llvm.org
Wed Aug 17 18:34:01 PDT 2022


Author: Jakub Kuderski
Date: 2022-08-17T21:33:34-04:00
New Revision: 95c4e518393cbb0d6ed2c615c08347960995c48a

URL: https://github.com/llvm/llvm-project/commit/95c4e518393cbb0d6ed2c615c08347960995c48a
DIFF: https://github.com/llvm/llvm-project/commit/95c4e518393cbb0d6ed2c615c08347960995c48a.diff

LOG: [mlir][spirv] Add arith.addi_carry to spv.IAddCarry conversion

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D131908

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
    mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index 52ab62c85dc56..56a241cd45122 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -13,8 +13,11 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "arith-to-spirv-pattern"
@@ -192,6 +195,15 @@ class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts arith.addi_carry to spv.IAddCarry.
+class AddICarryOpPattern final
+    : public OpConversionPattern<arith::AddICarryOp> {
+public:
+  using OpConversionPattern<arith::AddICarryOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(arith::AddICarryOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Converts arith.select to spv.Select.
 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
 public:
@@ -833,6 +845,34 @@ LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AddICarryOpPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+AddICarryOpPattern::matchAndRewrite(arith::AddICarryOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  // spv.IAddCarry returns a two-member struct; both members have the
+  // (already type-converted) operand type.
+  Location loc = op->getLoc();
+  Type elemTy = adaptor.getLhs().getType();
+  Value addCarry = rewriter.create<spirv::IAddCarryOp>(
+      loc, spirv::StructType::get({elemTy, elemTy}), adaptor.getLhs(),
+      adaptor.getRhs());
+
+  // Member 0 is the sum; member 1 is the carry, still as an integer.
+  Value sum = rewriter.create<spirv::CompositeExtractOp>(
+      loc, addCarry, llvm::makeArrayRef(0));
+  Value carryInt = rewriter.create<spirv::CompositeExtractOp>(
+      loc, addCarry, llvm::makeArrayRef(1));
+
+  // The arith op yields an i1 (or vector of i1) carry, so compare with one.
+  Value one = spirv::ConstantOp::getOne(elemTy, loc, rewriter);
+  Value carry = rewriter.create<spirv::IEqualOp>(loc, carryInt, one);
+  rewriter.replaceOp(op, {sum, carry});
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SelectOpPattern
 //===----------------------------------------------------------------------===//
@@ -887,7 +927,7 @@ void mlir::arith::populateArithmeticToSPIRVPatterns(
     TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
     CmpIOpBooleanPattern, CmpIOpPattern,
     CmpFOpNanNonePattern, CmpFOpPattern,
-    SelectOpPattern,
+    AddICarryOpPattern, SelectOpPattern,
 
     spirv::ElementwiseOpPattern<arith::MaxFOp, spirv::GLFMaxOp>,
     spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,

diff  --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
index 6b8cba22a0517..ca48648b7c1dd 100644
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -72,6 +72,33 @@ func.func @index_scalar_srem(%lhs: index, %rhs: index) {
   return
 }
 
+// Check integer add-with-carry conversions.
+// CHECK-LABEL: @int32_scalar_addi_carry
+// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+func.func @int32_scalar_addi_carry(%lhs: i32, %rhs: i32) -> (i32, i1) {
+  // CHECK-NEXT: %[[PAIR:.+]] = spv.IAddCarry %[[LHS]], %[[RHS]] : !spv.struct<(i32, i32)>
+  // CHECK-DAG:  %[[SUM:.+]] = spv.CompositeExtract %[[PAIR]][0 : i32] : !spv.struct<(i32, i32)>
+  // CHECK-DAG:  %[[CARRY_INT:.+]] = spv.CompositeExtract %[[PAIR]][1 : i32] : !spv.struct<(i32, i32)>
+  // CHECK-DAG:  %[[ONE:.+]] = spv.Constant 1 : i32
+  // CHECK-NEXT: %[[CARRY:.+]] = spv.IEqual %[[CARRY_INT]], %[[ONE]] : i32
+  // CHECK-NEXT: return %[[SUM]], %[[CARRY]] : i32, i1
+  %sum, %carry = arith.addi_carry %lhs, %rhs : i32, i1
+  return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @int32_vector_addi_carry
+// CHECK-SAME: (%[[LHS:.+]]: vector<4xi32>, %[[RHS:.+]]: vector<4xi32>)
+func.func @int32_vector_addi_carry(%lhs: vector<4xi32>, %rhs: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) {
+  // CHECK-NEXT: %[[PAIR:.+]] = spv.IAddCarry %[[LHS]], %[[RHS]] : !spv.struct<(vector<4xi32>, vector<4xi32>)>
+  // CHECK-DAG:  %[[SUM:.+]] = spv.CompositeExtract %[[PAIR]][0 : i32] : !spv.struct<(vector<4xi32>, vector<4xi32>)>
+  // CHECK-DAG:  %[[CARRY_INT:.+]] = spv.CompositeExtract %[[PAIR]][1 : i32] : !spv.struct<(vector<4xi32>, vector<4xi32>)>
+  // CHECK-DAG:  %[[ONE:.+]] = spv.Constant dense<1> : vector<4xi32>
+  // CHECK-NEXT: %[[CARRY:.+]] = spv.IEqual %[[CARRY_INT]], %[[ONE]] : vector<4xi32>
+  // CHECK-NEXT: return %[[SUM]], %[[CARRY]] : vector<4xi32>, vector<4xi1>
+  %sum, %carry = arith.addi_carry %lhs, %rhs : vector<4xi32>, vector<4xi1>
+  return %sum, %carry : vector<4xi32>, vector<4xi1>
+}
+
 // Check float unary operation conversions.
 // CHECK-LABEL: @float32_unary_scalar
 func.func @float32_unary_scalar(%arg0: f32) {


        


More information about the Mlir-commits mailing list