[Mlir-commits] [mlir] 95c4e51 - [mlir][spirv] Add arith.addi_carry to spv.IAddCarry conversion
Jakub Kuderski
llvmlistbot at llvm.org
Wed Aug 17 18:34:01 PDT 2022
Author: Jakub Kuderski
Date: 2022-08-17T21:33:34-04:00
New Revision: 95c4e518393cbb0d6ed2c615c08347960995c48a
URL: https://github.com/llvm/llvm-project/commit/95c4e518393cbb0d6ed2c615c08347960995c48a
DIFF: https://github.com/llvm/llvm-project/commit/95c4e518393cbb0d6ed2c615c08347960995c48a.diff
LOG: [mlir][spirv] Add arith.addi_carry to spv.IAddCarry conversion
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D131908
Added:
Modified:
mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index 52ab62c85dc56..56a241cd45122 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -13,8 +13,11 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "arith-to-spirv-pattern"
@@ -192,6 +195,15 @@ class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
ConversionPatternRewriter &rewriter) const override;
};
+/// Pattern lowering `arith.addi_carry` to `spv.IAddCarry`.
+struct AddICarryOpPattern final : OpConversionPattern<arith::AddICarryOp> {
+  using OpConversionPattern<arith::AddICarryOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::AddICarryOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
/// Converts arith.select to spv.Select.
class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
public:
@@ -833,6 +845,34 @@ LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
return success();
}
+//===----------------------------------------------------------------------===//
+// AddICarryOpPattern
+//===----------------------------------------------------------------------===//
+
+/// Lowers `arith.addi_carry` to `spv.IAddCarry`.
+///
+/// spv.IAddCarry yields a two-member struct {sum, carry} whose members both
+/// have the (converted) operand type. Both members are extracted, and the
+/// integer carry is turned into the i1 flag arith expects by comparing it
+/// with 1. Boolean (i1) operands are rejected: SPIR-V models i1 as a boolean
+/// type, which is not a valid operand of spv.IAddCarry or spv.IEqual.
+LogicalResult
+AddICarryOpPattern::matchAndRewrite(arith::AddICarryOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  Type dstElemTy = adaptor.getLhs().getType();
+
+  // Fail to match (rather than emit ops that fail verification) when the
+  // converted operands are boolean scalars or vectors of booleans.
+  Type dstScalarTy = dstElemTy;
+  if (auto vecTy = dstElemTy.dyn_cast<VectorType>())
+    dstScalarTy = vecTy.getElementType();
+  if (dstScalarTy.isInteger(1))
+    return rewriter.notifyMatchFailure(
+        op, "boolean operands are not supported by spv.IAddCarry");
+
+  auto resultTy = spirv::StructType::get({dstElemTy, dstElemTy});
+
+  Location loc = op->getLoc();
+  Value result = rewriter.create<spirv::IAddCarryOp>(
+      loc, resultTy, adaptor.getLhs(), adaptor.getRhs());
+
+  Value sumResult = rewriter.create<spirv::CompositeExtractOp>(
+      loc, result, llvm::makeArrayRef(0));
+  Value carryValue = rewriter.create<spirv::CompositeExtractOp>(
+      loc, result, llvm::makeArrayRef(1));
+
+  // The carry member is an integer of the operand type; convert it to a
+  // boolean by comparing it with 1.
+  Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
+  Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
+
+  rewriter.replaceOp(op, {sumResult, carryResult});
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// SelectOpPattern
//===----------------------------------------------------------------------===//
@@ -887,7 +927,7 @@ void mlir::arith::populateArithmeticToSPIRVPatterns(
TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
CmpIOpBooleanPattern, CmpIOpPattern,
CmpFOpNanNonePattern, CmpFOpPattern,
- SelectOpPattern,
+ AddICarryOpPattern, SelectOpPattern,
spirv::ElementwiseOpPattern<arith::MaxFOp, spirv::GLFMaxOp>,
spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
index 6b8cba22a0517..ca48648b7c1dd 100644
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -72,6 +72,33 @@ func.func @index_scalar_srem(%lhs: index, %rhs: index) {
return
}
+// Check integer add-with-carry conversions.
+// arith.addi_carry lowers to spv.IAddCarry, whose struct result is split with
+// two spv.CompositeExtract ops; the integer carry member is then compared
+// against a constant 1 with spv.IEqual to produce the i1 carry flag.
+// CHECK-LABEL: @int32_scalar_addi_carry
+// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+func.func @int32_scalar_addi_carry(%lhs: i32, %rhs: i32) -> (i32, i1) {
+ // The two extracts and the constant are matched in any relative order; the
+ // comparison and the return are required on the lines that follow.
+ // CHECK-NEXT: %[[IAC:.+]] = spv.IAddCarry %[[LHS]], %[[RHS]] : !spv.struct<(i32, i32)>
+ // CHECK-DAG: %[[SUM:.+]] = spv.CompositeExtract %[[IAC]][0 : i32] : !spv.struct<(i32, i32)>
+ // CHECK-DAG: %[[C0:.+]] = spv.CompositeExtract %[[IAC]][1 : i32] : !spv.struct<(i32, i32)>
+ // CHECK-DAG: %[[ONE:.+]] = spv.Constant 1 : i32
+ // CHECK-NEXT: %[[C1:.+]] = spv.IEqual %[[C0]], %[[ONE]] : i32
+ // CHECK-NEXT: return %[[SUM]], %[[C1]] : i32, i1
+ %sum, %carry = arith.addi_carry %lhs, %rhs: i32, i1
+ return %sum, %carry : i32, i1
+}
+
+// Vector form of the same lowering: the struct members are vector<4xi32>, the
+// constant 1 is a dense splat of that vector type, and the comparison yields
+// a vector<4xi1> carry.
+// CHECK-LABEL: @int32_vector_addi_carry
+// CHECK-SAME: (%[[LHS:.+]]: vector<4xi32>, %[[RHS:.+]]: vector<4xi32>)
+func.func @int32_vector_addi_carry(%lhs: vector<4xi32>, %rhs: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) {
+ // CHECK-NEXT: %[[IAC:.+]] = spv.IAddCarry %[[LHS]], %[[RHS]] : !spv.struct<(vector<4xi32>, vector<4xi32>)>
+ // CHECK-DAG: %[[SUM:.+]] = spv.CompositeExtract %[[IAC]][0 : i32] : !spv.struct<(vector<4xi32>, vector<4xi32>)>
+ // CHECK-DAG: %[[C0:.+]] = spv.CompositeExtract %[[IAC]][1 : i32] : !spv.struct<(vector<4xi32>, vector<4xi32>)>
+ // CHECK-DAG: %[[ONE:.+]] = spv.Constant dense<1> : vector<4xi32>
+ // CHECK-NEXT: %[[C1:.+]] = spv.IEqual %[[C0]], %[[ONE]] : vector<4xi32>
+ // CHECK-NEXT: return %[[SUM]], %[[C1]] : vector<4xi32>, vector<4xi1>
+ %sum, %carry = arith.addi_carry %lhs, %rhs: vector<4xi32>, vector<4xi1>
+ return %sum, %carry : vector<4xi32>, vector<4xi1>
+}
+
// Check float unary operation conversions.
// CHECK-LABEL: @float32_unary_scalar
func.func @float32_unary_scalar(%arg0: f32) {
More information about the Mlir-commits
mailing list