[Mlir-commits] [mlir] [mlir][emitc] Refactor ArithToEmitC: perform sign adaptation, type conversions / cast insertion in a single place (PR #95789)
Corentin Ferry
llvmlistbot at llvm.org
Mon Jun 17 07:02:21 PDT 2024
https://github.com/cferry-AMD created https://github.com/llvm/llvm-project/pull/95789
This PR lays the groundwork for a follow-up PR that will use the newly introduced EmitC types.
When certain values have to be interpreted as signed or unsigned (e.g. for ops like `extsi` and `extui`), type conversions may be required in EmitC (for instance, from `int` to `unsigned int`), whereas such conversions are not needed in `arith`, whose ops only take signless values as operands. Lowering of these ops now uses `adaptIntegralTypeSignedness` (to obtain the type with the required signedness) and `adaptValueType` (to insert an explicit cast of a value to a given EmitC type).
>From 7f0ab5eda8580c6a1a7d26569ac61cae35189b5f Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Mon, 17 Jun 2024 14:13:17 +0100
Subject: [PATCH] Refactor ArithToEmitC: adaptIntegralTypeSignedness
---
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 1 +
.../Conversion/ArithToEmitC/ArithToEmitC.cpp | 78 +++++++------------
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 6 ++
.../ArithToEmitC/arith-to-emitc.mlir | 7 ++
4 files changed, 40 insertions(+), 52 deletions(-)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 452302c565139..25d1983ec583b 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -288,6 +288,7 @@ def EmitC_CastOp : EmitC_Op<"cast",
let arguments = (ins EmitCType:$source);
let results = (outs EmitCType:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
+ let hasFolder = 1;
}
def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 74f0f61d04a1a..9214bc5b2c13e 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -270,19 +270,11 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
- Type arithmeticType = type;
- if (type.isUnsignedInteger() != needsUnsigned) {
- arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
- /*isSigned=*/!needsUnsigned);
- }
- Value lhs = adaptor.getLhs();
- Value rhs = adaptor.getRhs();
- if (arithmeticType != type) {
- lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
- lhs);
- rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
- rhs);
- }
+
+ Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
+ Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
+ Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
+
rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
return success();
}
@@ -328,37 +320,26 @@ class CastConversion : public OpConversionPattern<ArithOp> {
return success();
}
- bool isTruncation = operandType.getIntOrFloatBitWidth() >
- opReturnType.getIntOrFloatBitWidth();
+ bool isTruncation =
+ (isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
+ operandType.getIntOrFloatBitWidth() >
+ opReturnType.getIntOrFloatBitWidth());
bool doUnsigned = castToUnsigned || isTruncation;
- Type castType = opReturnType;
- // If the op is a ui variant and the type wanted as
- // return type isn't unsigned, we need to issue an unsigned type to do
- // the conversion.
- if (castType.isUnsignedInteger() != doUnsigned) {
- castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(),
- /*isSigned=*/!doUnsigned);
- }
+ // Adapt the signedness of the result (bitwidth-preserving cast)
+ // This is needed e.g., if the return type is signless.
+ Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
- Value actualOp = adaptor.getIn();
- // Adapt the signedness of the operand if necessary
- if (operandType.isUnsignedInteger() != doUnsigned) {
- Type correctSignednessType =
- rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
- /*isSigned=*/!doUnsigned);
- actualOp = rewriter.template create<emitc::CastOp>(
- op.getLoc(), correctSignednessType, actualOp);
- }
+ // Adapt the signedness of the operand (bitwidth-preserving cast)
+ Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
+ Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
- auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType,
- actualOp);
+ // Actual cast (may change bitwidth)
+ auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
+ castDestType, actualOp);
// Cast to the expected output type
- if (castType != opReturnType) {
- result = rewriter.template create<emitc::CastOp>(op.getLoc(),
- opReturnType, result);
- }
+ auto result = adaptValueType(cast, rewriter, opReturnType);
rewriter.replaceOp(op, result);
return success();
@@ -410,8 +391,6 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
}
- Value lhs = adaptor.getLhs();
- Value rhs = adaptor.getRhs();
Type arithmeticType = type;
if ((type.isSignlessInteger() || type.isSignedInteger()) &&
!bitEnumContainsAll(op.getOverflowFlags(),
@@ -421,20 +400,15 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
/*isSigned=*/false);
}
- if (arithmeticType != type) {
- lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
- lhs);
- rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
- rhs);
- }
- Value result = rewriter.template create<EmitCOp>(op.getLoc(),
- arithmeticType, lhs, rhs);
+ Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
+ Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
+
+ Value arithmeticResult = rewriter.template create<EmitCOp>(
+ op.getLoc(), arithmeticType, lhs, rhs);
+
+ Value result = adaptValueType(arithmeticResult, rewriter, type);
- if (arithmeticType != type) {
- result =
- rewriter.template create<emitc::CastOp>(op.getLoc(), type, result);
- }
rewriter.replaceOp(op, result);
return success();
}
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index b2556bb6065d8..c3c9b4e6a1d3e 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -241,6 +241,12 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
emitc::isSupportedFloatType(output) || isa<emitc::PointerType>(output)));
}
+OpFoldResult emitc::CastOp::fold(FoldAdaptor adaptor) {
+  Value source = getOperand(); // Identity cast (same type in and out): fold to source.
+  bool isIdentity = source.getType() == getResult().getType();
+  return isIdentity ? OpFoldResult(source) : OpFoldResult(nullptr);
+}
+
//===----------------------------------------------------------------------===//
// CallOpaqueOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 71f1a6abd913b..607e5bf9b1a3b 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -466,6 +466,13 @@ func.func @arith_trunci(%arg0: i32) -> i8 {
// CHECK: emitc.cast %[[Trunc]] : ui8 to i8
%truncd = arith.trunci %arg0 : i32 to i8
+ // CHECK: %[[Const:.*]] = "emitc.constant"
+ // CHECK-SAME: value = 1
+ // CHECK-SAME: () -> i32
+ // CHECK: %[[AndOne:.*]] = emitc.bitwise_and %[[Arg0]], %[[Const]] : (i32, i32) -> i32
+ // CHECK: %[[Conv:.*]] = emitc.cast %[[AndOne]] : i32 to i1
+ %bool = arith.trunci %arg0 : i32 to i1
+
return %truncd : i8
}
More information about the Mlir-commits
mailing list