[Mlir-commits] [mlir] 71db971 - [mlir][emitc] Arith to EmitC: Handle addi, subi and muli (#86120)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 22 07:39:57 PDT 2024
Author: Matthias Gehre
Date: 2024-03-22T15:39:52+01:00
New Revision: 71db97152173a524a3e16e02b7fdc50f405c8695
URL: https://github.com/llvm/llvm-project/commit/71db97152173a524a3e16e02b7fdc50f405c8695
DIFF: https://github.com/llvm/llvm-project/commit/71db97152173a524a3e16e02b7fdc50f405c8695.diff
LOG: [mlir][emitc] Arith to EmitC: Handle addi, subi and muli (#86120)
It is important to consider that `arith` has wrap-around semantics, while in C++
signed overflow is UB.
Unless the operation guarantees that no signed overflow happens, we will
perform the arithmetic in an equivalent unsigned type.
`bool` also doesn't wrap around in C++, and is not addressed here.
Added:
mlir/test/Conversion/ArithToEmitC/arith-to-emitc-failed.mlir
Modified:
mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 3532785c31b939..db493c1294ba2d 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -55,6 +55,55 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
}
};
+/// Converts an integer `arith` binary op (addi, subi, muli) into the
+/// corresponding EmitC op.
+///
+/// `arith` integer ops have wrap-around semantics, whereas signed overflow is
+/// undefined behavior in C/C++. Unless the op carries the `nsw` flag, signed
+/// and signless operands are therefore cast to an unsigned type, the operation
+/// is performed there, and the result is cast back to the original type.
+///
+/// Unsigned arithmetic types narrower than 32 bits are widened to `ui32`:
+/// C/C++ integral promotion would otherwise convert such operands to signed
+/// `int` before the operation, so e.g. `uint16_t * uint16_t` could overflow
+/// `int`. NOTE(review): this assumes the target's `int` is at most 32 bits
+/// wide — confirm for unusual targets.
+///
+/// `i1` is rejected because `bool` does not wrap around in C++.
+template <typename ArithOp, typename EmitCOp>
+class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
+public:
+  using OpConversionPattern<ArithOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Unsigned operands narrower than this are subject to integral promotion
+    // to (signed) `int` in C/C++, assuming a 32-bit `int`.
+    constexpr unsigned kMinUnpromotedBitWidth = 32;
+
+    Type type = this->getTypeConverter()->convertType(op.getType());
+    if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
+      return rewriter.notifyMatchFailure(op, "expected integer type");
+    }
+
+    if (type.isInteger(1)) {
+      // arith expects wrap-around arithmetic, which doesn't happen on `bool`.
+      return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
+    }
+
+    Type arithmeticType = type;
+    if ((type.isSignlessInteger() || type.isSignedInteger()) &&
+        !bitEnumContainsAll(op.getOverflowFlags(),
+                            arith::IntegerOverflowFlags::nsw)) {
+      // If the C type is signed and the op doesn't guarantee "No Signed Wrap",
+      // we compute in unsigned integers to avoid UB.
+      arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
+                                               /*isSigned=*/false);
+    }
+
+    // `isUnsignedInteger()` is checked first so that `getIntOrFloatBitWidth()`
+    // is only queried on an IntegerType (never on IndexType).
+    if (arithmeticType.isUnsignedInteger() &&
+        arithmeticType.getIntOrFloatBitWidth() < kMinUnpromotedBitWidth) {
+      // Narrow unsigned operands would be promoted to signed `int`,
+      // reintroducing possible signed overflow (notably for muli). Computing
+      // in a wider unsigned type keeps the low bits of the result identical;
+      // the final cast back to `type` truncates them into place.
+      arithmeticType = rewriter.getIntegerType(kMinUnpromotedBitWidth,
+                                               /*isSigned=*/false);
+    }
+
+    Value lhs = adaptor.getLhs();
+    Value rhs = adaptor.getRhs();
+    if (arithmeticType != type) {
+      lhs = rewriter.template create<emitc::CastOp>(op.getLoc(),
+                                                    arithmeticType, lhs);
+      rhs = rewriter.template create<emitc::CastOp>(op.getLoc(),
+                                                    arithmeticType, rhs);
+    }
+
+    Value result = rewriter.template create<EmitCOp>(op.getLoc(),
+                                                     arithmeticType, lhs, rhs);
+
+    if (arithmeticType != type) {
+      // Cast back to the converted result type expected by users of `op`.
+      result =
+          rewriter.template create<emitc::CastOp>(op.getLoc(), type, result);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
public:
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
@@ -96,6 +145,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
ArithOpConversion<arith::DivFOp, emitc::DivOp>,
ArithOpConversion<arith::MulFOp, emitc::MulOp>,
ArithOpConversion<arith::SubFOp, emitc::SubOp>,
+ IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
+ IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
+ IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
SelectOpConversion
>(typeConverter, ctx);
// clang-format on
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-failed.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-failed.mlir
new file mode 100644
index 00000000000000..30abd81f3d4470
--- /dev/null
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-failed.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt -convert-arith-to-emitc %s -split-input-file -verify-diagnostics
+
+// i1 integer ops are rejected by the integer patterns: C++ `bool` does not
+// wrap around the way arith requires, so legalization must fail.
+func.func @bool(%arg0: i1, %arg1: i1) {
+  // expected-error@+1 {{failed to legalize operation 'arith.addi'}}
+  %0 = arith.addi %arg0, %arg1 : i1
+  return
+}
+
+// -----
+
+// Vector-typed integer ops are not converted: the integer patterns only accept
+// integer and index result types, so legalization must fail.
+func.func @vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) {
+  // expected-error@+1 {{failed to legalize operation 'arith.addi'}}
+  %0 = arith.addi %arg0, %arg1 : vector<4xi32>
+  return
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 022530ef4db84b..76ba518577ab8e 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -37,6 +37,57 @@ func.func @arith_ops(%arg0: f32, %arg1: f32) {
// -----
+// CHECK-LABEL: arith_integer_ops
+// Signless i32 addi/subi/muli without overflow flags: both operands are cast
+// to ui32, the operation is emitted on ui32, and the result is cast back to
+// i32, so any wrap-around happens in an unsigned type instead of as signed
+// overflow.
+func.func @arith_integer_ops(%arg0: i32, %arg1: i32) {
+ // CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32
+ // CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32
+ // CHECK: %[[ADD:[^ ]*]] = emitc.add %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
+ // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[ADD]] : ui32 to i32
+ %0 = arith.addi %arg0, %arg1 : i32
+ // CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32
+ // CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32
+ // CHECK: %[[SUB:[^ ]*]] = emitc.sub %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
+ // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[SUB]] : ui32 to i32
+ %1 = arith.subi %arg0, %arg1 : i32
+ // CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32
+ // CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32
+ // CHECK: %[[MUL:[^ ]*]] = emitc.mul %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
+ // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[MUL]] : ui32 to i32
+ %2 = arith.muli %arg0, %arg1 : i32
+
+ return
+}
+
+// -----
+
+// CHECK-LABEL: arith_integer_ops_signed_nsw
+// With `overflow<nsw>` the op promises no signed wrap, so the EmitC ops are
+// emitted directly on the original i32 operands and result type, without
+// going through an unsigned type.
+func.func @arith_integer_ops_signed_nsw(%arg0: i32, %arg1: i32) {
+ // CHECK: emitc.add %arg0, %arg1 : (i32, i32) -> i32
+ %0 = arith.addi %arg0, %arg1 overflow<nsw> : i32
+ // CHECK: emitc.sub %arg0, %arg1 : (i32, i32) -> i32
+ %1 = arith.subi %arg0, %arg1 overflow<nsw> : i32
+ // CHECK: emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %2 = arith.muli %arg0, %arg1 overflow<nsw> : i32
+
+ return
+}
+
+// -----
+
+// CHECK-LABEL: arith_index
+// index ops carry no overflow flags here, yet the EmitC ops are still emitted
+// directly on the index operands: the signed-to-unsigned rewrite only applies
+// to signless/signed integer types, not to index.
+func.func @arith_index(%arg0: index, %arg1: index) {
+ // CHECK: emitc.add %arg0, %arg1 : (index, index) -> index
+ %0 = arith.addi %arg0, %arg1 : index
+ // CHECK: emitc.sub %arg0, %arg1 : (index, index) -> index
+ %1 = arith.subi %arg0, %arg1 : index
+ // CHECK: emitc.mul %arg0, %arg1 : (index, index) -> index
+ %2 = arith.muli %arg0, %arg1 : index
+
+ return
+}
+
+// -----
+
func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () {
// CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32>
%0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>
More information about the Mlir-commits
mailing list