[Mlir-commits] [mlir] [mlir][emitc] Lower arith.index_cast, arith.index_castui, arith.shli, arith.shrui, arith.shrsi (PR #95795)

Corentin Ferry llvmlistbot at llvm.org
Mon Jun 17 07:28:37 PDT 2024


https://github.com/cferry-AMD created https://github.com/llvm/llvm-project/pull/95795

This PR makes use of the newly introduced EmitC types (`size_t`, `ptrdiff_t`), which make it possible to lower:
* ops dealing with index types (`index_cast`, `index_castui`),
* ops where `size_t` is used as part of the lowering (`shli`, `shrui`, `shrsi`).

For the `shli`, `shrui`, and `shrsi` operations, we have to guard against out-of-range shift amounts: shifting by an amount greater than or equal to the bit width is undefined behavior per the C99 specification, whereas in MLIR it yields a poison value. Out-of-range shifts therefore produce 0, which is a valid refinement of poison. Where the bit width is not known statically (i.e. for values of type `index`), the check is computed using `sizeof`; it is then up to the target compiler to fold it away through constant propagation.

>From 7f0ab5eda8580c6a1a7d26569ac61cae35189b5f Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Mon, 17 Jun 2024 14:13:17 +0100
Subject: [PATCH 1/4] Refactor ArithToEmitC: adaptIntegralTypeSignedness

---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td   |  1 +
 .../Conversion/ArithToEmitC/ArithToEmitC.cpp  | 78 +++++++------------
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           |  6 ++
 .../ArithToEmitC/arith-to-emitc.mlir          |  7 ++
 4 files changed, 40 insertions(+), 52 deletions(-)

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 452302c565139..25d1983ec583b 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -288,6 +288,7 @@ def EmitC_CastOp : EmitC_Op<"cast",
   let arguments = (ins EmitCType:$source);
   let results = (outs EmitCType:$dest);
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
+  let hasFolder = 1;
 }
 
 def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 74f0f61d04a1a..9214bc5b2c13e 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -270,19 +270,11 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
 
     bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
     emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
-    Type arithmeticType = type;
-    if (type.isUnsignedInteger() != needsUnsigned) {
-      arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
-                                               /*isSigned=*/!needsUnsigned);
-    }
-    Value lhs = adaptor.getLhs();
-    Value rhs = adaptor.getRhs();
-    if (arithmeticType != type) {
-      lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
-                                                    lhs);
-      rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
-                                                    rhs);
-    }
+
+    Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
+    Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
+    Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
+
     rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
     return success();
   }
@@ -328,37 +320,26 @@ class CastConversion : public OpConversionPattern<ArithOp> {
       return success();
     }
 
-    bool isTruncation = operandType.getIntOrFloatBitWidth() >
-                        opReturnType.getIntOrFloatBitWidth();
+    bool isTruncation =
+        (isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
+         operandType.getIntOrFloatBitWidth() >
+             opReturnType.getIntOrFloatBitWidth());
     bool doUnsigned = castToUnsigned || isTruncation;
 
-    Type castType = opReturnType;
-    // If the op is a ui variant and the type wanted as
-    // return type isn't unsigned, we need to issue an unsigned type to do
-    // the conversion.
-    if (castType.isUnsignedInteger() != doUnsigned) {
-      castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(),
-                                         /*isSigned=*/!doUnsigned);
-    }
+    // Adapt the signedness of the result (bitwidth-preserving cast)
+    // This is needed e.g., if the return type is signless.
+    Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
 
-    Value actualOp = adaptor.getIn();
-    // Adapt the signedness of the operand if necessary
-    if (operandType.isUnsignedInteger() != doUnsigned) {
-      Type correctSignednessType =
-          rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
-                                  /*isSigned=*/!doUnsigned);
-      actualOp = rewriter.template create<emitc::CastOp>(
-          op.getLoc(), correctSignednessType, actualOp);
-    }
+    // Adapt the signedness of the operand (bitwidth-preserving cast)
+    Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
+    Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
 
-    auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType,
-                                                          actualOp);
+    // Actual cast (may change bitwidth)
+    auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
+                                                        castDestType, actualOp);
 
     // Cast to the expected output type
-    if (castType != opReturnType) {
-      result = rewriter.template create<emitc::CastOp>(op.getLoc(),
-                                                       opReturnType, result);
-    }
+    auto result = adaptValueType(cast, rewriter, opReturnType);
 
     rewriter.replaceOp(op, result);
     return success();
@@ -410,8 +391,6 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
       return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
     }
 
-    Value lhs = adaptor.getLhs();
-    Value rhs = adaptor.getRhs();
     Type arithmeticType = type;
     if ((type.isSignlessInteger() || type.isSignedInteger()) &&
         !bitEnumContainsAll(op.getOverflowFlags(),
@@ -421,20 +400,15 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
       arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
                                                /*isSigned=*/false);
     }
-    if (arithmeticType != type) {
-      lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
-                                                    lhs);
-      rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
-                                                    rhs);
-    }
 
-    Value result = rewriter.template create<EmitCOp>(op.getLoc(),
-                                                     arithmeticType, lhs, rhs);
+    Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
+    Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
+
+    Value arithmeticResult = rewriter.template create<EmitCOp>(
+        op.getLoc(), arithmeticType, lhs, rhs);
+
+    Value result = adaptValueType(arithmeticResult, rewriter, type);
 
-    if (arithmeticType != type) {
-      result =
-          rewriter.template create<emitc::CastOp>(op.getLoc(), type, result);
-    }
     rewriter.replaceOp(op, result);
     return success();
   }
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index b2556bb6065d8..c3c9b4e6a1d3e 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -241,6 +241,12 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
        emitc::isSupportedFloatType(output) || isa<emitc::PointerType>(output)));
 }
 
+OpFoldResult emitc::CastOp::fold(FoldAdaptor adaptor) {
+  // Casting a value to its own type is a no-op: fold to the source operand.
+  bool isNoOp = getSource().getType() == getDest().getType();
+  return isNoOp ? OpFoldResult(getSource()) : OpFoldResult();
+}
+
 //===----------------------------------------------------------------------===//
 // CallOpaqueOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 71f1a6abd913b..607e5bf9b1a3b 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -466,6 +466,13 @@ func.func @arith_trunci(%arg0: i32) -> i8 {
   // CHECK: emitc.cast %[[Trunc]] : ui8 to i8
   %truncd = arith.trunci %arg0 : i32 to i8
 
+  // CHECK: %[[Const:.*]] = "emitc.constant"
+  // CHECK-SAME: value = 1
+  // CHECK-SAME: () -> i32
+  // CHECK: %[[AndOne:.*]] = emitc.bitwise_and %[[Arg0]], %[[Const]] : (i32, i32) -> i32
+  // CHECK: %[[Conv:.*]] = emitc.cast %[[AndOne]] : i32 to i1
+  %bool = arith.trunci %arg0 : i32 to i1
+
   return %truncd : i8
 }
 

>From f64e96aae717fef79b3efb3a88858fcf725a4aa4 Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Mon, 17 Jun 2024 14:43:53 +0100
Subject: [PATCH 2/4] Use new types in EmitC, lower index_cast

---
 .../Conversion/ArithToEmitC/ArithToEmitC.cpp  | 52 ++++++++---
 .../Conversion/ArithToEmitC/CMakeLists.txt    |  1 +
 .../ArithToEmitC/arith-to-emitc.mlir          | 87 +++++++++++++++++--
 3 files changed, 120 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 9214bc5b2c13e..0599083f8f1bf 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -15,6 +15,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Support/LogicalResult.h"
@@ -36,8 +37,11 @@ class ArithConstantOpConversionPattern
   matchAndRewrite(arith::ConstantOp arithConst,
                   arith::ConstantOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<emitc::ConstantOp>(
-        arithConst, arithConst.getType(), adaptor.getValue());
+    Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
+    if (!newTy)
+      return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
+    rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy,
+                                                   adaptor.getValue());
     return success();
   }
 };
@@ -52,6 +56,12 @@ Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
       return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(),
                               signedness);
     }
+  } else if (emitc::isPointerWideType(ty)) {
+    if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
+      if (needsUnsigned)
+        return emitc::SizeTType::get(ty.getContext());
+      return emitc::PtrDiffTType::get(ty.getContext());
+    }
   }
   return ty;
 }
@@ -264,8 +274,9 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     Type type = adaptor.getLhs().getType();
-    if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
-      return rewriter.notifyMatchFailure(op, "expected integer or index type");
+    if (type && !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
+      return rewriter.notifyMatchFailure(
+          op, "expected integer or size_t/ssize_t type");
     }
 
     bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
@@ -290,8 +301,10 @@ class CastConversion : public OpConversionPattern<ArithOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     Type opReturnType = this->getTypeConverter()->convertType(op.getType());
-    if (!isa_and_nonnull<IntegerType>(opReturnType))
-      return rewriter.notifyMatchFailure(op, "expected integer result type");
+    if (opReturnType && !(isa_and_nonnull<IntegerType>(opReturnType) ||
+                          emitc::isPointerWideType(opReturnType)))
+      return rewriter.notifyMatchFailure(
+          op, "expected integer or size_t/ssize_t result type");
 
     if (adaptor.getOperands().size() != 1) {
       return rewriter.notifyMatchFailure(
@@ -299,8 +312,10 @@ class CastConversion : public OpConversionPattern<ArithOp> {
     }
 
     Type operandType = adaptor.getIn().getType();
-    if (!isa_and_nonnull<IntegerType>(operandType))
-      return rewriter.notifyMatchFailure(op, "expected integer operand type");
+    if (operandType && !(isa_and_nonnull<IntegerType>(operandType) ||
+                         emitc::isPointerWideType(operandType)))
+      return rewriter.notifyMatchFailure(
+          op, "expected integer or size_t/ssize_t operand type");
 
     // Signed (sign-extending) casts from i1 are not supported.
     if (operandType.isInteger(1) && !castToUnsigned)
@@ -311,8 +326,11 @@ class CastConversion : public OpConversionPattern<ArithOp> {
     // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
     // truncation.
     if (opReturnType.isInteger(1)) {
+      Type attrType = (emitc::isPointerWideType(operandType))
+                          ? rewriter.getIndexType()
+                          : operandType;
       auto constOne = rewriter.create<emitc::ConstantOp>(
-          op.getLoc(), operandType, rewriter.getIntegerAttr(operandType, 1));
+          op.getLoc(), operandType, rewriter.getIntegerAttr(attrType, 1));
       auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>(
           op.getLoc(), operandType, adaptor.getIn(), constOne);
       rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
@@ -365,7 +383,11 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
   matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, arithOp.getType(),
+    Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
+    if (!newTy)
+      return rewriter.notifyMatchFailure(arithOp,
+                                         "converting result type failed");
+    rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
                                                   adaptor.getOperands());
 
     return success();
@@ -382,8 +404,10 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     Type type = this->getTypeConverter()->convertType(op.getType());
-    if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
-      return rewriter.notifyMatchFailure(op, "expected integer type");
+    if (type && !(isa_and_nonnull<IntegerType>(type) ||
+                  emitc::isPointerWideType(type))) {
+      return rewriter.notifyMatchFailure(
+          op, "expected integer or size_t/ssize_t/ptrdiff_t type");
     }
 
     if (type.isInteger(1)) {
@@ -578,6 +602,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
                                         RewritePatternSet &patterns) {
   MLIRContext *ctx = patterns.getContext();
 
+  mlir::populateEmitCSizeTTypeConversions(typeConverter);
+
   // clang-format off
   patterns.add<
     ArithConstantOpConversionPattern,
@@ -600,6 +626,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
     UnsignedCastConversion<arith::TruncIOp>,
     SignedCastConversion<arith::ExtSIOp>,
     UnsignedCastConversion<arith::ExtUIOp>,
+    SignedCastConversion<arith::IndexCastOp>,
+    UnsignedCastConversion<arith::IndexCastUIOp>,
     ItoFCastOpConversion<arith::SIToFPOp>,
     ItoFCastOpConversion<arith::UIToFPOp>,
     FtoICastOpConversion<arith::FPToSIOp>,
diff --git a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
index a3784f47c3bc2..730a4b341673d 100644
--- a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRArithToEmitC
   LINK_LIBS PUBLIC
   MLIRArithDialect
   MLIREmitCDialect
+  MLIREmitCTransforms
   MLIRPass
   MLIRTransformUtils
   )
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 607e5bf9b1a3b..fd19beadd6a06 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -3,7 +3,8 @@
 // CHECK-LABEL: arith_constants
 func.func @arith_constants() {
   // CHECK: emitc.constant
-  // CHECK-SAME: value = 0 : index
+  // CHECK-SAME: value = 0
+  // CHECK-SAME: () -> !emitc.size_t
   %c_index = arith.constant 0 : index
   // CHECK: emitc.constant
   // CHECK-SAME: value = 0 : i32
@@ -75,13 +76,18 @@ func.func @arith_integer_ops_signed_nsw(%arg0: i32, %arg1: i32) {
 // -----
 
 // CHECK-LABEL: arith_index
-func.func @arith_index(%arg0: index, %arg1: index) {
-  // CHECK: emitc.add %arg0, %arg1 : (index, index) -> index
-  %0 = arith.addi %arg0, %arg1 : index
-  // CHECK: emitc.sub %arg0, %arg1 : (index, index) -> index
-  %1 = arith.subi %arg0, %arg1 : index
-  // CHECK: emitc.mul %arg0, %arg1 : (index, index) -> index
-  %2 = arith.muli %arg0, %arg1 : index
+func.func @arith_index(%arg0: i32, %arg1: i32) {
+  // CHECK: %[[CST0:.*]] = emitc.cast %{{.*}} : {{.*}} to !emitc.size_t
+  %cst0 = arith.index_cast %arg0 : i32 to index
+  // CHECK: %[[CST1:.*]] = emitc.cast %{{.*}} : {{.*}} to !emitc.size_t
+  %cst1 = arith.index_cast %arg1 : i32 to index
+
+  // CHECK: emitc.add %[[CST0]], %[[CST1]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  %0 = arith.addi %cst0, %cst1 : index
+  // CHECK: emitc.sub %[[CST0]], %[[CST1]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  %1 = arith.subi %cst0, %cst1 : index
+  // CHECK: emitc.mul %[[CST0]], %[[CST1]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  %2 = arith.muli %cst0, %cst1 : index
 
   return
 }
@@ -420,6 +426,27 @@ func.func @arith_cmpi_predicates(%arg0: i32, %arg1: i32) {
   return
 }
 
+func.func @arith_cmpi_index(%arg0: i32, %arg1: i32) -> i1 {
+  // CHECK-LABEL: arith_cmpi_index
+  // Compare two distinct indices (from %arg0 and %arg1), unsigned and signed.
+  // CHECK: %[[Cst0:.*]] = emitc.cast %{{.*}} : {{.*}} to !emitc.size_t
+  %idx0 = arith.index_cast %arg0 : i32 to index
+  // CHECK: %[[Cst1:.*]] = emitc.cast %{{.*}} : {{.*}} to !emitc.size_t
+  %idx1 = arith.index_cast %arg1 : i32 to index
+
+  // CHECK-DAG: %[[ULT:[^ ]*]] = emitc.cmp lt, %[[Cst0]], %[[Cst1]] : (!emitc.size_t, !emitc.size_t) -> i1
+  %ult = arith.cmpi ult, %idx0, %idx1 : index
+
+  // CHECK-DAG: %[[CastArg0:[^ ]*]] = emitc.cast %[[Cst0]] : !emitc.size_t to !emitc.ptrdiff_t
+  // CHECK-DAG: %[[CastArg1:[^ ]*]] = emitc.cast %[[Cst1]] : !emitc.size_t to !emitc.ptrdiff_t
+  // CHECK-DAG: %[[SLT:[^ ]*]] = emitc.cmp lt, %[[CastArg0]], %[[CastArg1]] : (!emitc.ptrdiff_t, !emitc.ptrdiff_t) -> i1
+  %slt = arith.cmpi slt, %idx0, %idx1 : index
+
+  // CHECK: return %[[SLT]]
+  return %slt: i1
+}
+
+
 // -----
 
 func.func @arith_float_to_int_cast_ops(%arg0: f32, %arg1: f64) {
@@ -525,3 +552,47 @@ func.func @arith_extui_i1_to_i32(%arg0: i1) {
   %idx = arith.extui %arg0 : i1 to i32
   return
 }
+
+// -----
+
+func.func @arith_index_cast(%arg0: i32) -> i32 {
+  // CHECK-LABEL: arith_index_cast
+  // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
+  // CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to !emitc.ptrdiff_t
+  // CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : !emitc.ptrdiff_t to !emitc.size_t
+  %idx = arith.index_cast %arg0 : i32 to index
+  // CHECK: %[[Conv2:.*]] = emitc.cast %[[Conv1]] : !emitc.size_t to !emitc.ptrdiff_t
+  // CHECK: %[[Conv3:.*]] = emitc.cast %[[Conv2]] : !emitc.ptrdiff_t to i32
+  %int = arith.index_cast %idx : index to i32
+
+  // CHECK: %[[Const:.*]] = "emitc.constant"
+  // CHECK-SAME: value = 1
+  // CHECK-SAME: () -> !emitc.size_t
+  // CHECK: %[[AndOne:.*]] = emitc.bitwise_and %[[Conv1]], %[[Const]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  // CHECK: %[[Conv4:.*]] = emitc.cast %[[AndOne]] : !emitc.size_t to i1
+  %bool = arith.index_cast %idx : index to i1
+
+  return %int : i32
+}
+
+// -----
+
+func.func @arith_index_castui(%arg0: i32) -> i32 {
+  // CHECK-LABEL: arith_index_castui
+  // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
+  // CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
+  // CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui32 to !emitc.size_t
+  %idx = arith.index_castui %arg0 : i32 to index
+  // CHECK: %[[Conv2:.*]] = emitc.cast %[[Conv1]] : !emitc.size_t to ui32
+  // CHECK: %[[Conv3:.*]] = emitc.cast %[[Conv2]] : ui32 to i32
+  %int = arith.index_castui %idx : index to i32
+
+  // CHECK: %[[Const:.*]] = "emitc.constant"
+  // CHECK-SAME: value = 1
+  // CHECK-SAME: () -> !emitc.size_t
+  // CHECK: %[[AndOne:.*]] = emitc.bitwise_and %[[Conv1]], %[[Const]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  // CHECK: %[[Conv4:.*]] = emitc.cast %[[AndOne]] : !emitc.size_t to i1
+  %bool = arith.index_castui %idx : index to i1
+
+  return %int : i32
+}

>From d889b080419075ba5f6c74fe847ee6b848580516 Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Wed, 29 May 2024 10:48:49 +0200
Subject: [PATCH 3/4] Add shift operations

---
 .../Conversion/ArithToEmitC/ArithToEmitC.cpp  |  88 ++++++++++++++
 .../arith-to-emitc-unsupported.mlir           |  24 ++++
 .../ArithToEmitC/arith-to-emitc.mlir          | 110 ++++++++++++++++++
 3 files changed, 222 insertions(+)

diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 0599083f8f1bf..cae647d357434 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Region.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -478,6 +479,90 @@ class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
   }
 };
 
+template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
+class ShiftOpConversion : public OpConversionPattern<ArithOp> {
+public:
+  using OpConversionPattern<ArithOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Type type = this->getTypeConverter()->convertType(op.getType());
+    if (type && !(isa_and_nonnull<IntegerType>(type) ||
+                  emitc::isPointerWideType(type))) {
+      return rewriter.notifyMatchFailure(
+          op, "expected integer or size_t/ssize_t type");
+    }
+
+    if (type.isInteger(1)) {
+      return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
+    }
+    // NOTE(review): a `type && ...` guard lets a null type reach isInteger().
+    Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);
+
+    Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
+    // Shift amount interpreted as unsigned per Arith dialect spec.
+    Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
+                                               /*needsUnsigned=*/true);
+    Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);
+    // Shifting by >= the bit width is UB in C; compute the width to guard it.
+    // Pointer-wide types: width = 8 * sizeof(<rhsType value>), evaluated in C.
+    Value width;
+    if (emitc::isPointerWideType(type)) {
+      Value eight = rewriter.create<emitc::ConstantOp>(
+          op.getLoc(), rhsType, rewriter.getIndexAttr(8));
+      emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>(
+          op.getLoc(), rhsType, "sizeof", SmallVector<Value, 1>({eight}));
+      width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
+                                            sizeOfCall.getResult(0));
+    } else {
+      width = rewriter.create<emitc::ConstantOp>(
+          op.getLoc(), rhsType,
+          rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
+    }
+    // True when the amount is in range (rhs < width), despite the name.
+    Value excessCheck = rewriter.create<emitc::CmpOp>(
+        op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);
+
+    // Result for out-of-range amounts; any concrete value refines poison.
+    Value poison = rewriter.create<emitc::ConstantOp>(
+        op.getLoc(), arithmeticType,
+        (isa<IntegerType>(arithmeticType)
+             ? rewriter.getIntegerAttr(arithmeticType, 0)
+             : rewriter.getIndexAttr(0)));
+    // Build `excessCheck ? lhs <op> rhs : poison` as one emitc.expression.
+    emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
+        op.getLoc(), arithmeticType, /*do_not_inline=*/false);
+    Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
+    auto currentPoint = rewriter.getInsertionPoint();
+    rewriter.setInsertionPointToStart(&bodyBlock);
+    Value arithmeticResult =
+        rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs);
+    Value resultOrPoison = rewriter.create<emitc::ConditionalOp>(
+        op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison);
+    rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison);
+    rewriter.setInsertionPoint(op->getBlock(), currentPoint);
+    // Cast back from arithmeticType to the converted result type if needed.
+    Value result = adaptValueType(ternary, rewriter, type);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+template <typename ArithOp, typename EmitCOp>
+class SignedShiftOpConversion final
+    : public ShiftOpConversion<ArithOp, EmitCOp, false> {
+  using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
+};
+
+template <typename ArithOp, typename EmitCOp>
+class UnsignedShiftOpConversion final
+    : public ShiftOpConversion<ArithOp, EmitCOp, true> {
+  using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
+};
+
 class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
 public:
   using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
@@ -619,6 +704,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
     BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
     BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
     BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
+    UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
+    SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
+    UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
     CmpFOpConversion,
     CmpIOpConversion,
     SelectOpConversion,
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index c07289109e6dd..d3b31c03c5d13 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -86,3 +86,27 @@ func.func @arith_extsi_i1_to_i32(%arg0: i1) {
   %idx = arith.extsi %arg0 : i1 to i32
   return
 }
+
+// -----
+
+func.func @arith_shli_i1(%arg0: i1, %arg1: i1) {
+  // expected-error @+1 {{failed to legalize operation 'arith.shli'}}
+  %shli = arith.shli %arg0, %arg1 : i1
+  return
+}
+
+// -----
+
+func.func @arith_shrsi_i1(%arg0: i1, %arg1: i1) {
+  // expected-error @+1 {{failed to legalize operation 'arith.shrsi'}}
+  %shrsi = arith.shrsi %arg0, %arg1 : i1
+  return
+}
+
+// -----
+
+func.func @arith_shrui_i1(%arg0: i1, %arg1: i1) {
+  // expected-error @+1 {{failed to legalize operation 'arith.shrui'}}
+  %shrui = arith.shrui %arg0, %arg1 : i1
+  return
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index fd19beadd6a06..bb743901b599c 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -144,6 +144,116 @@ func.func @arith_signed_integer_div_rem(%arg0: i32, %arg1: i32) {
 
 // -----
 
+// CHECK-LABEL: arith_shift_left
+// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+func.func @arith_shift_left(%arg0: i32, %arg1: i32) {
+  // CHECK-DAG: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
+  // CHECK-DAG: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
+  // CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32
+  // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1
+  // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0
+  // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32
+  // CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
+  // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : ui32
+  // CHECK: emitc.yield %[[Ternary]] : ui32
+  // CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32
+  %1 = arith.shli %arg0, %arg1 : i32
+  return
+}
+
+// -----
+
+// CHECK-LABEL: arith_shift_right
+// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+func.func @arith_shift_right(%arg0: i32, %arg1: i32) {
+  // CHECK-DAG: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
+  // CHECK-DAG: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
+  // CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32
+  // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1
+  // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}ui32
+  // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32
+  // CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
+  // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : ui32
+  // CHECK: emitc.yield %[[Ternary]] : ui32
+  // CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32
+  %2 = arith.shrui %arg0, %arg1 : i32
+
+  // CHECK-DAG: %[[SC2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
+  // CHECK-DAG: %[[SSizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32
+  // CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[SC2]], %[[SSizeConstant]] : (ui32, ui32) -> i1
+  // CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}i32
+  // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : i32
+  // CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[ARG0]], %[[SC2]] : (i32, ui32) -> i32
+  // CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : i32
+  // CHECK: emitc.yield %[[STernary]] : i32
+  %3 = arith.shrsi %arg0, %arg1 : i32
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: arith_shift_left_index
+// CHECK-SAME: %[[AMOUNT:.*]]: i32
+func.func @arith_shift_left_index(%amount: i32) {
+  %cst0 = "arith.constant"() {value = 42 : index} : () -> (index)
+  %cast1 = arith.index_cast %amount : i32 to index
+  // CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t
+  // CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ssize_t
+  // CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ssize_t to !emitc.size_t
+  // CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index
+  // CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t
+  // CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
+  // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0
+  // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t
+  // CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : !emitc.size_t
+  // CHECK: emitc.yield %[[Ternary]] : !emitc.size_t
+  %1 = arith.shli %cst0, %cast1 : index
+  return
+}
+
+// -----
+
+// CHECK-LABEL: arith_shift_right_index
+// CHECK-SAME: %[[AMOUNT:.*]]: i32
+func.func @arith_shift_right_index(%amount: i32) {
+  // CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t
+  // CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ssize_t
+  // CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ssize_t to !emitc.size_t
+  %arg0 = "arith.constant"() {value = 42 : index} : () -> (index)
+  %arg1 = arith.index_cast %amount : i32 to index
+
+  // CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index
+  // CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t
+  // CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
+  // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.size_t
+  // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t
+  // CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : !emitc.size_t
+  // CHECK: emitc.yield %[[Ternary]] : !emitc.size_t
+  %2 = arith.shrui %arg0, %arg1 : index
+
+  // CHECK-DAG: %[[SC1:[^ ]*]] = emitc.cast %[[C1]] : !emitc.size_t to !emitc.ssize_t
+  // CHECK-DAG: %[[SByte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index{{.*}}!emitc.size_t
+  // CHECK-DAG: %[[SSizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[SByte]]) : (!emitc.size_t) -> !emitc.size_t
+  // CHECK-DAG: %[[SSizeConstant:[^ ]*]] = emitc.mul %[[SByte]], %[[SSizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  // CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SSizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
+  // CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.ssize_t
+  // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : !emitc.ssize_t
+  // CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[SC1]], %[[AmountIdx]] : (!emitc.ssize_t, !emitc.size_t) -> !emitc.ssize_t
+  // CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : !emitc.ssize_t
+  // CHECK: emitc.yield %[[STernary]] : !emitc.ssize_t
+  // CHECK: emitc.cast %[[SShiftRes]] : !emitc.ssize_t to !emitc.size_t
+  %3 = arith.shrsi %arg0, %arg1 : index
+
+  return
+}
+
+// -----
+
 func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () {
   // CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32>
   %0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>

>From a05c6aab39b2a14d4d134d98f89c0011faba38dd Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Mon, 17 Jun 2024 15:25:11 +0100
Subject: [PATCH 4/4] Fix type mismatches, possibly null operands

---
 .../Conversion/ArithToEmitC/ArithToEmitC.cpp  | 18 +++++++--------
 .../ArithToEmitC/arith-to-emitc.mlir          | 22 +++++++++----------
 2 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index cae647d357434..b4f3acb6127fc 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -275,9 +275,9 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     Type type = adaptor.getLhs().getType();
-    if (type && !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
+    if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
       return rewriter.notifyMatchFailure(
-          op, "expected integer or size_t/ssize_t type");
+          op, "expected integer or size_t/ssize_t/ptrdiff_t type");
     }
 
     bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
@@ -302,10 +302,10 @@ class CastConversion : public OpConversionPattern<ArithOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     Type opReturnType = this->getTypeConverter()->convertType(op.getType());
-    if (opReturnType && !(isa_and_nonnull<IntegerType>(opReturnType) ||
+    if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
                           emitc::isPointerWideType(opReturnType)))
       return rewriter.notifyMatchFailure(
-          op, "expected integer or size_t/ssize_t result type");
+          op, "expected integer or size_t/ssize_t/ptrdiff_t result type");
 
     if (adaptor.getOperands().size() != 1) {
       return rewriter.notifyMatchFailure(
@@ -313,10 +313,10 @@ class CastConversion : public OpConversionPattern<ArithOp> {
     }
 
     Type operandType = adaptor.getIn().getType();
-    if (operandType && !(isa_and_nonnull<IntegerType>(operandType) ||
+    if (!operandType || !(isa<IntegerType>(operandType) ||
                          emitc::isPointerWideType(operandType)))
       return rewriter.notifyMatchFailure(
-          op, "expected integer or size_t/ssize_t operand type");
+          op, "expected integer or size_t/ssize_t/ptrdiff_t operand type");
 
     // Signed (sign-extending) casts from i1 are not supported.
     if (operandType.isInteger(1) && !castToUnsigned)
@@ -405,7 +405,7 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     Type type = this->getTypeConverter()->convertType(op.getType());
-    if (type && !(isa_and_nonnull<IntegerType>(type) ||
+    if (!type || !(isa_and_nonnull<IntegerType>(type) ||
                   emitc::isPointerWideType(type))) {
       return rewriter.notifyMatchFailure(
           op, "expected integer or size_t/ssize_t/ptrdiff_t type");
@@ -489,10 +489,10 @@ class ShiftOpConversion : public OpConversionPattern<ArithOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     Type type = this->getTypeConverter()->convertType(op.getType());
-    if (type && !(isa_and_nonnull<IntegerType>(type) ||
+    if (!type || !(isa_and_nonnull<IntegerType>(type) ||
                   emitc::isPointerWideType(type))) {
       return rewriter.notifyMatchFailure(
-          op, "expected integer or size_t/ssize_t type");
+          op, "expected integer or size_t/ssize_t/ptrdiff_t type");
     }
 
     if (type.isInteger(1)) {
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index bb743901b599c..20df06ea7bd91 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -199,8 +199,8 @@ func.func @arith_shift_left_index(%amount: i32) {
   %cst0 = "arith.constant"() {value = 42 : index} : () -> (index)
   %cast1 = arith.index_cast %amount : i32 to index
   // CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t
-  // CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ssize_t
-  // CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ssize_t to !emitc.size_t
+  // CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ptrdiff_t
+  // CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ptrdiff_t to !emitc.size_t
   // CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index
   // CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t
   // CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
@@ -220,8 +220,8 @@ func.func @arith_shift_left_index(%amount: i32) {
 // CHECK-SAME: %[[AMOUNT:.*]]: i32
 func.func @arith_shift_right_index(%amount: i32) {
   // CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t
-  // CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ssize_t
-  // CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ssize_t to !emitc.size_t
+  // CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ptrdiff_t
+  // CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ptrdiff_t to !emitc.size_t
   %arg0 = "arith.constant"() {value = 42 : index} : () -> (index)
   %arg1 = arith.index_cast %amount : i32 to index
 
@@ -236,17 +236,17 @@ func.func @arith_shift_right_index(%amount: i32) {
   // CHECK: emitc.yield %[[Ternary]] : !emitc.size_t
   %2 = arith.shrui %arg0, %arg1 : index
 
-  // CHECK-DAG: %[[SC1:[^ ]*]] = emitc.cast %[[C1]] : !emitc.size_t to !emitc.ssize_t
+  // CHECK-DAG: %[[SC1:[^ ]*]] = emitc.cast %[[C1]] : !emitc.size_t to !emitc.ptrdiff_t
   // CHECK-DAG: %[[SByte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index{{.*}}!emitc.size_t
   // CHECK-DAG: %[[SSizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[SByte]]) : (!emitc.size_t) -> !emitc.size_t
   // CHECK-DAG: %[[SSizeConstant:[^ ]*]] = emitc.mul %[[SByte]], %[[SSizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
   // CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SSizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
-  // CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.ssize_t
-  // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : !emitc.ssize_t
-  // CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[SC1]], %[[AmountIdx]] : (!emitc.ssize_t, !emitc.size_t) -> !emitc.ssize_t
-  // CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : !emitc.ssize_t
-  // CHECK: emitc.yield %[[STernary]] : !emitc.ssize_t
-  // CHECK: emitc.cast %[[SShiftRes]] : !emitc.ssize_t to !emitc.size_t
+  // CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.ptrdiff_t
+  // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : !emitc.ptrdiff_t
+  // CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[SC1]], %[[AmountIdx]] : (!emitc.ptrdiff_t, !emitc.size_t) -> !emitc.ptrdiff_t
+  // CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : !emitc.ptrdiff_t
+  // CHECK: emitc.yield %[[STernary]] : !emitc.ptrdiff_t
+  // CHECK: emitc.cast %[[SShiftRes]] : !emitc.ptrdiff_t to !emitc.size_t
   %3 = arith.shrsi %arg0, %arg1 : index
 
   return



More information about the Mlir-commits mailing list