[Mlir-commits] [mlir] [mlir][emitc] Refactor ArithToEmitC: perform sign adaptation, type conversions / cast insertion in a single place (PR #95789)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 17 07:06:58 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-emitc

@llvm/pr-subscribers-mlir

Author: Corentin Ferry (cferry-AMD)

<details>
<summary>Changes</summary>

This PR lays the groundwork for a follow-up PR that will use the newly introduced EmitC types.

When certain values have to be interpreted as signed/unsigned (e.g. for ops like `extsi` and `extui`), type conversions in EmitC may be required (from `int` to `unsigned int` for instance); such conversions are not needed in `arith` ops that only take signless values as operands. Lowering of such ops now uses `adaptIntegralTypeSignedness` (to obtain the type with the right signedness) and `adaptValueType` (to insert an explicit cast of a value to a given EmitC type). A folder is added to `emitc::CastOp` so that any cast from a type to itself folds away; the source and destination types therefore don't need to be checked for equality before inserting a cast.

While doing this refactoring, the need emerged for one more `trunci` case (truncation to i1, which is well-defined and uses a special case of `CastOpConversion`).

---
Full diff: https://github.com/llvm/llvm-project/pull/95789.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitC.td (+1) 
- (modified) mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp (+26-52) 
- (modified) mlir/lib/Dialect/EmitC/IR/EmitC.cpp (+6) 
- (modified) mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir (+7) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 452302c565139..25d1983ec583b 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -288,6 +288,7 @@ def EmitC_CastOp : EmitC_Op<"cast",
   let arguments = (ins EmitCType:$source);
   let results = (outs EmitCType:$dest);
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
+  let hasFolder = 1;
 }
 
 def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 74f0f61d04a1a..9214bc5b2c13e 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -270,19 +270,11 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
 
     bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
     emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
-    Type arithmeticType = type;
-    if (type.isUnsignedInteger() != needsUnsigned) {
-      arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
-                                               /*isSigned=*/!needsUnsigned);
-    }
-    Value lhs = adaptor.getLhs();
-    Value rhs = adaptor.getRhs();
-    if (arithmeticType != type) {
-      lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
-                                                    lhs);
-      rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
-                                                    rhs);
-    }
+
+    Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
+    Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
+    Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
+
     rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
     return success();
   }
@@ -328,37 +320,26 @@ class CastConversion : public OpConversionPattern<ArithOp> {
       return success();
     }
 
-    bool isTruncation = operandType.getIntOrFloatBitWidth() >
-                        opReturnType.getIntOrFloatBitWidth();
+    bool isTruncation =
+        (isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
+         operandType.getIntOrFloatBitWidth() >
+             opReturnType.getIntOrFloatBitWidth());
     bool doUnsigned = castToUnsigned || isTruncation;
 
-    Type castType = opReturnType;
-    // If the op is a ui variant and the type wanted as
-    // return type isn't unsigned, we need to issue an unsigned type to do
-    // the conversion.
-    if (castType.isUnsignedInteger() != doUnsigned) {
-      castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(),
-                                         /*isSigned=*/!doUnsigned);
-    }
+    // Adapt the signedness of the result (bitwidth-preserving cast)
+    // This is needed e.g., if the return type is signless.
+    Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
 
-    Value actualOp = adaptor.getIn();
-    // Adapt the signedness of the operand if necessary
-    if (operandType.isUnsignedInteger() != doUnsigned) {
-      Type correctSignednessType =
-          rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
-                                  /*isSigned=*/!doUnsigned);
-      actualOp = rewriter.template create<emitc::CastOp>(
-          op.getLoc(), correctSignednessType, actualOp);
-    }
+    // Adapt the signedness of the operand (bitwidth-preserving cast)
+    Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
+    Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
 
-    auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType,
-                                                          actualOp);
+    // Actual cast (may change bitwidth)
+    auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
+                                                        castDestType, actualOp);
 
     // Cast to the expected output type
-    if (castType != opReturnType) {
-      result = rewriter.template create<emitc::CastOp>(op.getLoc(),
-                                                       opReturnType, result);
-    }
+    auto result = adaptValueType(cast, rewriter, opReturnType);
 
     rewriter.replaceOp(op, result);
     return success();
@@ -410,8 +391,6 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
       return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
     }
 
-    Value lhs = adaptor.getLhs();
-    Value rhs = adaptor.getRhs();
     Type arithmeticType = type;
     if ((type.isSignlessInteger() || type.isSignedInteger()) &&
         !bitEnumContainsAll(op.getOverflowFlags(),
@@ -421,20 +400,15 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
       arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
                                                /*isSigned=*/false);
     }
-    if (arithmeticType != type) {
-      lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
-                                                    lhs);
-      rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
-                                                    rhs);
-    }
 
-    Value result = rewriter.template create<EmitCOp>(op.getLoc(),
-                                                     arithmeticType, lhs, rhs);
+    Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
+    Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
+
+    Value arithmeticResult = rewriter.template create<EmitCOp>(
+        op.getLoc(), arithmeticType, lhs, rhs);
+
+    Value result = adaptValueType(arithmeticResult, rewriter, type);
 
-    if (arithmeticType != type) {
-      result =
-          rewriter.template create<emitc::CastOp>(op.getLoc(), type, result);
-    }
     rewriter.replaceOp(op, result);
     return success();
   }
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index b2556bb6065d8..c3c9b4e6a1d3e 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -241,6 +241,12 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
        emitc::isSupportedFloatType(output) || isa<emitc::PointerType>(output)));
 }
 
+OpFoldResult emitc::CastOp::fold(FoldAdaptor adaptor) {
+  if (getOperand().getType() == getResult().getType())
+    return getOperand();
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // CallOpaqueOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 71f1a6abd913b..607e5bf9b1a3b 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -466,6 +466,13 @@ func.func @arith_trunci(%arg0: i32) -> i8 {
   // CHECK: emitc.cast %[[Trunc]] : ui8 to i8
   %truncd = arith.trunci %arg0 : i32 to i8
 
+  // CHECK: %[[Const:.*]] = "emitc.constant"
+  // CHECK-SAME: value = 1
+  // CHECK-SAME: () -> i32
+  // CHECK: %[[AndOne:.*]] = emitc.bitwise_and %[[Arg0]], %[[Const]] : (i32, i32) -> i32
+  // CHECK: %[[Conv:.*]] = emitc.cast %[[AndOne]] : i32 to i1
+  %bool = arith.trunci %arg0 : i32 to i1
+
   return %truncd : i8
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/95789


More information about the Mlir-commits mailing list