[Mlir-commits] [mlir] 519175c - [mlir][emitc] Refactor ArithToEmitC: perform sign adaptation, type conversions / cast insertion in a single place (#95789)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 19 00:19:36 PDT 2024


Author: Corentin Ferry
Date: 2024-06-19T09:19:33+02:00
New Revision: 519175c3f5d844bac0cf3173396dc41db2873e1d

URL: https://github.com/llvm/llvm-project/commit/519175c3f5d844bac0cf3173396dc41db2873e1d
DIFF: https://github.com/llvm/llvm-project/commit/519175c3f5d844bac0cf3173396dc41db2873e1d.diff

LOG: [mlir][emitc] Refactor ArithToEmitC: perform sign adaptation, type conversions / cast insertion in a single place (#95789)

Factor out the EmitC type-signedness adaptation and cast insertion in ArithToEmitC by routing them through the adaptValueType and adaptIntegralTypeSignedness helpers.

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
    mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 27913dff8cb85..93717e3b02ef0 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -270,19 +270,11 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
 
     bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
     emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
-    Type arithmeticType = type;
-    if (type.isUnsignedInteger() != needsUnsigned) {
-      arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
-                                               /*isSigned=*/!needsUnsigned);
-    }
-    Value lhs = adaptor.getLhs();
-    Value rhs = adaptor.getRhs();
-    if (arithmeticType != type) {
-      lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
-                                                    lhs);
-      rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
-                                                    rhs);
-    }
+
+    Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
+    Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
+    Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
+
     rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
     return success();
   }
@@ -356,37 +348,26 @@ class CastConversion : public OpConversionPattern<ArithOp> {
       return success();
     }
 
-    bool isTruncation = operandType.getIntOrFloatBitWidth() >
-                        opReturnType.getIntOrFloatBitWidth();
+    bool isTruncation =
+        (isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
+         operandType.getIntOrFloatBitWidth() >
+             opReturnType.getIntOrFloatBitWidth());
     bool doUnsigned = castToUnsigned || isTruncation;
 
-    Type castType = opReturnType;
-    // If the op is a ui variant and the type wanted as
-    // return type isn't unsigned, we need to issue an unsigned type to do
-    // the conversion.
-    if (castType.isUnsignedInteger() != doUnsigned) {
-      castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(),
-                                         /*isSigned=*/!doUnsigned);
-    }
+    // Adapt the signedness of the result (bitwidth-preserving cast)
+    // This is needed e.g., if the return type is signless.
+    Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
 
-    Value actualOp = adaptor.getIn();
-    // Adapt the signedness of the operand if necessary
-    if (operandType.isUnsignedInteger() != doUnsigned) {
-      Type correctSignednessType =
-          rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
-                                  /*isSigned=*/!doUnsigned);
-      actualOp = rewriter.template create<emitc::CastOp>(
-          op.getLoc(), correctSignednessType, actualOp);
-    }
+    // Adapt the signedness of the operand (bitwidth-preserving cast)
+    Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
+    Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
 
-    auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType,
-                                                          actualOp);
+    // Actual cast (may change bitwidth)
+    auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
+                                                        castDestType, actualOp);
 
     // Cast to the expected output type
-    if (castType != opReturnType) {
-      result = rewriter.template create<emitc::CastOp>(op.getLoc(),
-                                                       opReturnType, result);
-    }
+    auto result = adaptValueType(cast, rewriter, opReturnType);
 
     rewriter.replaceOp(op, result);
     return success();
@@ -438,8 +419,6 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
       return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
     }
 
-    Value lhs = adaptor.getLhs();
-    Value rhs = adaptor.getRhs();
     Type arithmeticType = type;
     if ((type.isSignlessInteger() || type.isSignedInteger()) &&
         !bitEnumContainsAll(op.getOverflowFlags(),
@@ -449,20 +428,15 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
       arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
                                                /*isSigned=*/false);
     }
-    if (arithmeticType != type) {
-      lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
-                                                    lhs);
-      rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
-                                                    rhs);
-    }
 
-    Value result = rewriter.template create<EmitCOp>(op.getLoc(),
-                                                     arithmeticType, lhs, rhs);
+    Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
+    Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
+
+    Value arithmeticResult = rewriter.template create<EmitCOp>(
+        op.getLoc(), arithmeticType, lhs, rhs);
+
+    Value result = adaptValueType(arithmeticResult, rewriter, type);
 
-    if (arithmeticType != type) {
-      result =
-          rewriter.template create<emitc::CastOp>(op.getLoc(), type, result);
-    }
     rewriter.replaceOp(op, result);
     return success();
   }

diff  --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 667ff795178a6..0289b7dc0728f 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -477,6 +477,13 @@ func.func @arith_trunci(%arg0: i32) -> i8 {
   // CHECK: emitc.cast %[[Trunc]] : ui8 to i8
   %truncd = arith.trunci %arg0 : i32 to i8
 
+  // CHECK: %[[Const:.*]] = "emitc.constant"
+  // CHECK-SAME: value = 1
+  // CHECK-SAME: () -> i32
+  // CHECK: %[[AndOne:.*]] = emitc.bitwise_and %[[Arg0]], %[[Const]] : (i32, i32) -> i32
+  // CHECK: %[[Conv:.*]] = emitc.cast %[[AndOne]] : i32 to i1
+  %bool = arith.trunci %arg0 : i32 to i1
+
   return %truncd : i8
 }
 


        


More information about the Mlir-commits mailing list