[Mlir-commits] [mlir] [mlir][emitc] Arith to EmitC: Handle addi, subi and muli (PR #86120)

Matthias Gehre llvmlistbot at llvm.org
Fri Mar 22 06:45:28 PDT 2024


https://github.com/mgehre-amd updated https://github.com/llvm/llvm-project/pull/86120

>From d509738aa5298535d1e9075d22f71eb8a34f903a Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Thu, 14 Mar 2024 21:24:22 +0100
Subject: [PATCH 1/3] [mlir][emitc] Arith to EmitC: Handle addi, subi and muli

No handling yet for divsi and divui, as they require
special considerations for signedness.
---
 mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp   |  3 +++
 .../Conversion/ArithToEmitC/arith-to-emitc.mlir     | 13 +++++++++++++
 2 files changed, 16 insertions(+)

diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 3532785c31b939..e85bb0f6b227b9 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -96,6 +96,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
     ArithOpConversion<arith::DivFOp, emitc::DivOp>,
     ArithOpConversion<arith::MulFOp, emitc::MulOp>,
     ArithOpConversion<arith::SubFOp, emitc::SubOp>,
+    ArithOpConversion<arith::AddIOp, emitc::AddOp>,
+    ArithOpConversion<arith::MulIOp, emitc::MulOp>,
+    ArithOpConversion<arith::SubIOp, emitc::SubOp>,
     SelectOpConversion
   >(typeConverter, ctx);
   // clang-format on
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 022530ef4db84b..e5f2c330b851c3 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -37,6 +37,19 @@ func.func @arith_ops(%arg0: f32, %arg1: f32) {
 
 // -----
 
+func.func @arith_integer_ops(%arg0: i32, %arg1: i32) {
+  // CHECK: emitc.add %arg0, %arg1 : (i32, i32) -> i32
+  %0 = arith.addi %arg0, %arg1 : i32
+  // CHECK: emitc.sub %arg0, %arg1 : (i32, i32) -> i32
+  %1 = arith.subi %arg0, %arg1 : i32
+  // CHECK: emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+  %2 = arith.muli %arg0, %arg1 : i32
+
+  return
+}
+
+// -----
+
 func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () {
   // CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32>
   %0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>

>From b4e0e1cc9f403952f6952f35d4eae43fd60d37a9 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Thu, 21 Mar 2024 23:02:11 +0100
Subject: [PATCH 2/3] Avoid UB due to signed wrap around

---
 .../Conversion/ArithToEmitC/ArithToEmitC.cpp  | 50 +++++++++++++++++--
 .../ArithToEmitC/arith-to-emitc.mlir          | 44 ++++++++++++++--
 2 files changed, 88 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index e85bb0f6b227b9..280adc5bd6270d 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -55,6 +55,50 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
   }
 };
 
+template <typename ArithOp, typename EmitCOp>
+class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
+public:
+  using OpConversionPattern<ArithOp>::OpConversionPattern;
+  /// Lowers the op to `EmitCOp`, casting around it when needed to avoid UB.
+  LogicalResult
+  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only scalar integer and index results are handled (no vectors/tensors).
+    Type type = this->getTypeConverter()->convertType(op.getType());
+    if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
+      return rewriter.notifyMatchFailure(op, "expected integer type");
+    }
+
+    Value lhs = adaptor.getLhs();
+    Value rhs = adaptor.getRhs();
+    Type arithmeticType = type;
+    // Signed overflow is UB in C, and C promotes types narrower than `int` to
+    // signed `int` (so `uint16_t * uint16_t` may overflow). Unless a signed op
+    // is nsw, compute in unsigned of >= 32 bits (assumes `int` <= 32 bits).
+    if (auto intType = dyn_cast<IntegerType>(type)) {
+      unsigned width = intType.getWidth();
+      bool nsw = bitEnumContainsAll(op.getOverflowFlags(),
+                                    arith::IntegerOverflowFlags::nsw);
+      if (intType.isUnsigned() ? width < 32 : !nsw)
+        arithmeticType = rewriter.getIntegerType(width < 32 ? 32 : width,
+                                                 /*isSigned=*/false);
+    }
+    auto createCast = [&](Type to, Value v) -> Value {
+      return rewriter.template create<emitc::CastOp>(op.getLoc(), to, v);
+    };
+    if (arithmeticType != type) {
+      lhs = createCast(arithmeticType, lhs);
+      rhs = createCast(arithmeticType, rhs);
+    }
+    Value result = rewriter.template create<EmitCOp>(op.getLoc(),
+                                                     arithmeticType, lhs, rhs);
+    if (arithmeticType != type) // Restore the op's original result type.
+      result = createCast(type, result);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
 public:
   using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
@@ -96,9 +140,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
     ArithOpConversion<arith::DivFOp, emitc::DivOp>,
     ArithOpConversion<arith::MulFOp, emitc::MulOp>,
     ArithOpConversion<arith::SubFOp, emitc::SubOp>,
-    ArithOpConversion<arith::AddIOp, emitc::AddOp>,
-    ArithOpConversion<arith::MulIOp, emitc::MulOp>,
-    ArithOpConversion<arith::SubIOp, emitc::SubOp>,
+    IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
+    IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
+    IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
     SelectOpConversion
   >(typeConverter, ctx);
   // clang-format on
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index e5f2c330b851c3..499c5f7397e112 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -37,12 +37,22 @@ func.func @arith_ops(%arg0: f32, %arg1: f32) {
 
 // -----
 
+// CHECK-LABEL: arith_integer_ops
 func.func @arith_integer_ops(%arg0: i32, %arg1: i32) {
-  // CHECK: emitc.add %arg0, %arg1 : (i32, i32) -> i32
+  // CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui3
+  // CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui3
+  // CHECK: %[[ADD:[^ ]*]] = emitc.add %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
+  // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[ADD]] : ui32 to i3
   %0 = arith.addi %arg0, %arg1 : i32
-  // CHECK: emitc.sub %arg0, %arg1 : (i32, i32) -> i32
+  // CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui3
+  // CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui3
+  // CHECK: %[[SUB:[^ ]*]] = emitc.sub %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
+  // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[SUB]] : ui32 to i3
   %1 = arith.subi %arg0, %arg1 : i32
-  // CHECK: emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+  // CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui3
+  // CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui3
+  // CHECK: %[[MUL:[^ ]*]] = emitc.mul %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
+  // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[MUL]] : ui32 to i3
   %2 = arith.muli %arg0, %arg1 : i32
 
   return
@@ -50,6 +60,34 @@ func.func @arith_integer_ops(%arg0: i32, %arg1: i32) {
 
 // -----
 
+// CHECK-LABEL: arith_integer_ops_signed_nsw
+func.func @arith_integer_ops_signed_nsw(%arg0: i32, %arg1: i32) {
+  // CHECK: emitc.add %arg0, %arg1 : (i32, i32) -> i32
+  %sum = arith.addi %arg0, %arg1 overflow<nsw> : i32
+  // CHECK: emitc.sub %arg0, %arg1 : (i32, i32) -> i32
+  %diff = arith.subi %arg0, %arg1 overflow<nsw> : i32
+  // CHECK: emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+  %prod = arith.muli %arg0, %arg1 overflow<nsw> : i32
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: arith_index
+func.func @arith_index(%arg0: index, %arg1: index) {
+  // CHECK: emitc.add %arg0, %arg1 : (index, index) -> index
+  %sum = arith.addi %arg0, %arg1 : index
+  // CHECK: emitc.sub %arg0, %arg1 : (index, index) -> index
+  %diff = arith.subi %arg0, %arg1 : index
+  // CHECK: emitc.mul %arg0, %arg1 : (index, index) -> index
+  %prod = arith.muli %arg0, %arg1 : index
+
+  return
+}
+
+// -----
+
 func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () {
   // CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32>
   %0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>

>From ec3e3ae80ab94e4db5669b84ff4863ad1be5b817 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Fri, 22 Mar 2024 14:43:58 +0100
Subject: [PATCH 3/3] Fix truncated check lines; exclude bool

---
 .../Conversion/ArithToEmitC/ArithToEmitC.cpp   |  5 +++++
 .../ArithToEmitC/arith-to-emitc-failed.mlir    | 15 +++++++++++++++
 .../ArithToEmitC/arith-to-emitc.mlir           | 18 +++++++++---------
 3 files changed, 29 insertions(+), 9 deletions(-)
 create mode 100644 mlir/test/Conversion/ArithToEmitC/arith-to-emitc-failed.mlir

diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 280adc5bd6270d..db493c1294ba2d 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -69,6 +69,11 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
       return rewriter.notifyMatchFailure(op, "expected integer type");
     }
 
+    if (type.isInteger(1)) {
+      // arith expects wrap-around arithmetic, which doesn't happen on `bool`.
+      return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
+    }
+
     Value lhs = adaptor.getLhs();
     Value rhs = adaptor.getRhs();
     Type arithmeticType = type;
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-failed.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-failed.mlir
new file mode 100644
index 00000000000000..a68344d9249715
--- /dev/null
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-failed.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt -convert-arith-to-emitc %s -split-input-file -verify-diagnostics
+
+func.func @bool(%arg0 : i1, %arg1 : i1) {
+  // expected-error@+1 {{failed to legalize operation 'arith.addi'}}
+  %0 = arith.addi %arg0, %arg1 : i1
+  return
+}
+
+// -----
+
+func.func @vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
+  // expected-error@+1 {{failed to legalize operation 'arith.addi'}}
+  %0 = arith.addi %arg0, %arg1 : vector<4xi32>
+  return
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 499c5f7397e112..76ba518577ab8e 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -39,20 +39,20 @@ func.func @arith_ops(%arg0: f32, %arg1: f32) {
 
 // CHECK-LABEL: arith_integer_ops
 func.func @arith_integer_ops(%arg0: i32, %arg1: i32) {
-  // CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui3
-  // CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui3
+  // CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32
+  // CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32
   // CHECK: %[[ADD:[^ ]*]] = emitc.add %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
-  // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[ADD]] : ui32 to i3
+  // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[ADD]] : ui32 to i32
   %0 = arith.addi %arg0, %arg1 : i32
-  // CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui3
-  // CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui3
+  // CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32
+  // CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32
   // CHECK: %[[SUB:[^ ]*]] = emitc.sub %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
-  // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[SUB]] : ui32 to i3
+  // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[SUB]] : ui32 to i32
   %1 = arith.subi %arg0, %arg1 : i32
-  // CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui3
-  // CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui3
+  // CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32
+  // CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32
   // CHECK: %[[MUL:[^ ]*]] = emitc.mul %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
-  // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[MUL]] : ui32 to i3
+  // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[MUL]] : ui32 to i32
   %2 = arith.muli %arg0, %arg1 : i32
 
   return



More information about the Mlir-commits mailing list