[Mlir-commits] [mlir] [mlir][emitc] Lower arith.index_cast, arith.index_castui, arith.shli, arith.shrui, arith.shrsi (PR #95795)
Corentin Ferry
llvmlistbot at llvm.org
Wed Jun 19 00:23:46 PDT 2024
https://github.com/cferry-AMD updated https://github.com/llvm/llvm-project/pull/95795
>From 5ed5160e9e7fc46970b5360ed99989a044899eb3 Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Mon, 17 Jun 2024 14:43:53 +0100
Subject: [PATCH 1/4] Use new types in EmitC, lower index_cast
---
.../Conversion/ArithToEmitC/ArithToEmitC.cpp | 52 ++++++++---
.../Conversion/ArithToEmitC/CMakeLists.txt | 1 +
.../ArithToEmitC/arith-to-emitc.mlir | 87 +++++++++++++++++--
3 files changed, 120 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 93717e3b02ef0..2a9784b727802 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
@@ -36,8 +37,11 @@ class ArithConstantOpConversionPattern
matchAndRewrite(arith::ConstantOp arithConst,
arith::ConstantOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<emitc::ConstantOp>(
- arithConst, arithConst.getType(), adaptor.getValue());
+ Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
+ if (!newTy)
+ return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
+ rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy,
+ adaptor.getValue());
return success();
}
};
@@ -52,6 +56,12 @@ Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(),
signedness);
}
+ } else if (emitc::isPointerWideType(ty)) {
+ if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
+ if (needsUnsigned)
+ return emitc::SizeTType::get(ty.getContext());
+ return emitc::PtrDiffTType::get(ty.getContext());
+ }
}
return ty;
}
@@ -264,8 +274,9 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
ConversionPatternRewriter &rewriter) const override {
Type type = adaptor.getLhs().getType();
- if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
- return rewriter.notifyMatchFailure(op, "expected integer or index type");
+ if (type && !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
+ return rewriter.notifyMatchFailure(
+ op, "expected integer or size_t/ssize_t type");
}
bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
@@ -318,8 +329,10 @@ class CastConversion : public OpConversionPattern<ArithOp> {
ConversionPatternRewriter &rewriter) const override {
Type opReturnType = this->getTypeConverter()->convertType(op.getType());
- if (!isa_and_nonnull<IntegerType>(opReturnType))
- return rewriter.notifyMatchFailure(op, "expected integer result type");
+ if (opReturnType && !(isa_and_nonnull<IntegerType>(opReturnType) ||
+ emitc::isPointerWideType(opReturnType)))
+ return rewriter.notifyMatchFailure(
+ op, "expected integer or size_t/ssize_t result type");
if (adaptor.getOperands().size() != 1) {
return rewriter.notifyMatchFailure(
@@ -327,8 +340,10 @@ class CastConversion : public OpConversionPattern<ArithOp> {
}
Type operandType = adaptor.getIn().getType();
- if (!isa_and_nonnull<IntegerType>(operandType))
- return rewriter.notifyMatchFailure(op, "expected integer operand type");
+ if (operandType && !(isa_and_nonnull<IntegerType>(operandType) ||
+ emitc::isPointerWideType(operandType)))
+ return rewriter.notifyMatchFailure(
+ op, "expected integer or size_t/ssize_t operand type");
// Signed (sign-extending) casts from i1 are not supported.
if (operandType.isInteger(1) && !castToUnsigned)
@@ -339,8 +354,11 @@ class CastConversion : public OpConversionPattern<ArithOp> {
// equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
// truncation.
if (opReturnType.isInteger(1)) {
+ Type attrType = (emitc::isPointerWideType(operandType))
+ ? rewriter.getIndexType()
+ : operandType;
auto constOne = rewriter.create<emitc::ConstantOp>(
- op.getLoc(), operandType, rewriter.getIntegerAttr(operandType, 1));
+ op.getLoc(), operandType, rewriter.getIntegerAttr(attrType, 1));
auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>(
op.getLoc(), operandType, adaptor.getIn(), constOne);
rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
@@ -393,7 +411,11 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, arithOp.getType(),
+ Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
+ if (!newTy)
+ return rewriter.notifyMatchFailure(arithOp,
+ "converting result type failed");
+ rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
adaptor.getOperands());
return success();
@@ -410,8 +432,10 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
ConversionPatternRewriter &rewriter) const override {
Type type = this->getTypeConverter()->convertType(op.getType());
- if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
- return rewriter.notifyMatchFailure(op, "expected integer type");
+ if (type && !(isa_and_nonnull<IntegerType>(type) ||
+ emitc::isPointerWideType(type))) {
+ return rewriter.notifyMatchFailure(
+ op, "expected integer or size_t/ssize_t/ptrdiff_t type");
}
if (type.isInteger(1)) {
@@ -606,6 +630,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *ctx = patterns.getContext();
+ mlir::populateEmitCSizeTTypeConversions(typeConverter);
+
// clang-format off
patterns.add<
ArithConstantOpConversionPattern,
@@ -629,6 +655,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
UnsignedCastConversion<arith::TruncIOp>,
SignedCastConversion<arith::ExtSIOp>,
UnsignedCastConversion<arith::ExtUIOp>,
+ SignedCastConversion<arith::IndexCastOp>,
+ UnsignedCastConversion<arith::IndexCastUIOp>,
ItoFCastOpConversion<arith::SIToFPOp>,
ItoFCastOpConversion<arith::UIToFPOp>,
FtoICastOpConversion<arith::FPToSIOp>,
diff --git a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
index a3784f47c3bc2..730a4b341673d 100644
--- a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRArithToEmitC
LINK_LIBS PUBLIC
MLIRArithDialect
MLIREmitCDialect
+ MLIREmitCTransforms
MLIRPass
MLIRTransformUtils
)
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 0289b7dc0728f..89a57d1d7cebc 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -3,7 +3,8 @@
// CHECK-LABEL: arith_constants
func.func @arith_constants() {
// CHECK: emitc.constant
- // CHECK-SAME: value = 0 : index
+ // CHECK-SAME: value = 0
+ // CHECK-SAME: () -> !emitc.size_t
%c_index = arith.constant 0 : index
// CHECK: emitc.constant
// CHECK-SAME: value = 0 : i32
@@ -75,13 +76,18 @@ func.func @arith_integer_ops_signed_nsw(%arg0: i32, %arg1: i32) {
// -----
// CHECK-LABEL: arith_index
-func.func @arith_index(%arg0: index, %arg1: index) {
- // CHECK: emitc.add %arg0, %arg1 : (index, index) -> index
- %0 = arith.addi %arg0, %arg1 : index
- // CHECK: emitc.sub %arg0, %arg1 : (index, index) -> index
- %1 = arith.subi %arg0, %arg1 : index
- // CHECK: emitc.mul %arg0, %arg1 : (index, index) -> index
- %2 = arith.muli %arg0, %arg1 : index
+func.func @arith_index(%arg0: i32, %arg1: i32) {
+ // CHECK: %[[CST0:.*]] = emitc.cast %{{.*}} : {{.*}} to !emitc.size_t
+ %cst0 = arith.index_cast %arg0 : i32 to index
+ // CHECK: %[[CST1:.*]] = emitc.cast %{{.*}} : {{.*}} to !emitc.size_t
+ %cst1 = arith.index_cast %arg1 : i32 to index
+
+ // CHECK: emitc.add %[[CST0]], %[[CST1]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+ %0 = arith.addi %cst0, %cst1 : index
+ // CHECK: emitc.sub %[[CST0]], %[[CST1]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+ %1 = arith.subi %cst0, %cst1 : index
+ // CHECK: emitc.mul %[[CST0]], %[[CST1]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+ %2 = arith.muli %cst0, %cst1 : index
return
}
@@ -420,6 +426,27 @@ func.func @arith_cmpi_predicates(%arg0: i32, %arg1: i32) {
return
}
+// CHECK-LABEL: arith_cmpi_index
+func.func @arith_cmpi_index(%arg0: i32, %arg1: i32) -> i1 {
+
+ // CHECK: %[[Cst0:.*]] = emitc.cast %{{.*}} : {{.*}} to !emitc.size_t
+ %idx0 = arith.index_cast %arg0 : i32 to index
+ // CHECK: %[[Cst1:.*]] = emitc.cast %{{.*}} : {{.*}} to !emitc.size_t
+ %idx1 = arith.index_cast %arg1 : i32 to index
+
+ // CHECK-DAG: %[[ULT:[^ ]*]] = emitc.cmp lt, %[[Cst0]], %[[Cst1]] : (!emitc.size_t, !emitc.size_t) -> i1
+ %ult = arith.cmpi ult, %idx0, %idx1 : index
+
+ // CHECK-DAG: %[[CastArg0:[^ ]*]] = emitc.cast %[[Cst0]] : !emitc.size_t to !emitc.ptrdiff_t
+ // CHECK-DAG: %[[CastArg1:[^ ]*]] = emitc.cast %[[Cst1]] : !emitc.size_t to !emitc.ptrdiff_t
+ // CHECK-DAG: %[[SLT:[^ ]*]] = emitc.cmp lt, %[[CastArg0]], %[[CastArg1]] : (!emitc.ptrdiff_t, !emitc.ptrdiff_t) -> i1
+ %slt = arith.cmpi slt, %idx0, %idx1 : index
+
+ // CHECK: return %[[SLT]]
+ return %slt : i1
+}
+
+
// -----
func.func @arith_negf(%arg0: f32) -> f32 {
@@ -536,3 +563,47 @@ func.func @arith_extui_i1_to_i32(%arg0: i1) {
%idx = arith.extui %arg0 : i1 to i32
return
}
+
+// -----
+
+// CHECK-LABEL: arith_index_cast
+// CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
+func.func @arith_index_cast(%arg0: i32) -> i32 {
+ // CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to !emitc.ptrdiff_t
+ // CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : !emitc.ptrdiff_t to !emitc.size_t
+ %idx = arith.index_cast %arg0 : i32 to index
+ // CHECK: %[[Conv2:.*]] = emitc.cast %[[Conv1]] : !emitc.size_t to !emitc.ptrdiff_t
+ // CHECK: %[[Conv3:.*]] = emitc.cast %[[Conv2]] : !emitc.ptrdiff_t to i32
+ %int = arith.index_cast %idx : index to i32
+
+ // CHECK: %[[Const:.*]] = "emitc.constant"
+ // CHECK-SAME: value = 1
+ // CHECK-SAME: () -> !emitc.size_t
+ // CHECK: %[[AndOne:.*]] = emitc.bitwise_and %[[Conv1]], %[[Const]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+ // CHECK: %[[Conv4:.*]] = emitc.cast %[[AndOne]] : !emitc.size_t to i1
+ %bool = arith.index_cast %idx : index to i1
+
+ return %int : i32
+}
+
+// -----
+
+// CHECK-LABEL: arith_index_castui
+// CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
+func.func @arith_index_castui(%arg0: i32) -> i32 {
+ // CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
+ // CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui32 to !emitc.size_t
+ %idx = arith.index_castui %arg0 : i32 to index
+ // CHECK: %[[Conv2:.*]] = emitc.cast %[[Conv1]] : !emitc.size_t to ui32
+ // CHECK: %[[Conv3:.*]] = emitc.cast %[[Conv2]] : ui32 to i32
+ %int = arith.index_castui %idx : index to i32
+
+ // CHECK: %[[Const:.*]] = "emitc.constant"
+ // CHECK-SAME: value = 1
+ // CHECK-SAME: () -> !emitc.size_t
+ // CHECK: %[[AndOne:.*]] = emitc.bitwise_and %[[Conv1]], %[[Const]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+ // CHECK: %[[Conv4:.*]] = emitc.cast %[[AndOne]] : !emitc.size_t to i1
+ %bool = arith.index_castui %idx : index to i1
+
+ return %int : i32
+}
>From a4a10e9c24284560d8f2932d04349b2a14aa9c59 Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Wed, 29 May 2024 10:48:49 +0200
Subject: [PATCH 2/4] Add shift operations
---
.../Conversion/ArithToEmitC/ArithToEmitC.cpp | 88 ++++++++++++++
.../arith-to-emitc-unsupported.mlir | 24 ++++
.../ArithToEmitC/arith-to-emitc.mlir | 110 ++++++++++++++++++
3 files changed, 222 insertions(+)
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 2a9784b727802..c16dcf65868ff 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Region.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -506,6 +507,90 @@ class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
}
};
+template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
+class ShiftOpConversion : public OpConversionPattern<ArithOp> {
+public:
+ using OpConversionPattern<ArithOp>::OpConversionPattern;
+ /// Lowers to (rhs < bitwidth ? lhs OP rhs : 0) built as one inlined emitc.expression.
+ LogicalResult
+ matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ Type type = this->getTypeConverter()->convertType(op.getType());
+ if (type && !(isa_and_nonnull<IntegerType>(type) ||
+ emitc::isPointerWideType(type))) {
+ return rewriter.notifyMatchFailure(
+ op, "expected integer or size_t/ssize_t type");
+ }
+
+ if (type.isInteger(1)) {
+ return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
+ }
+
+ Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);
+
+ Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
+ // Shift amount interpreted as unsigned per Arith dialect spec.
+ Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
+ /*needsUnsigned=*/true);
+ Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);
+
+ // C shifts by >= the bit width are UB: compute the width for a range guard.
+ Value width;
+ if (emitc::isPointerWideType(type)) { // 8 * sizeof(eight); NOTE(review): relies on eight being a rhsType variable in C
+ Value eight = rewriter.create<emitc::ConstantOp>(
+ op.getLoc(), rhsType, rewriter.getIndexAttr(8));
+ emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>(
+ op.getLoc(), rhsType, "sizeof", SmallVector<Value, 1>({eight}));
+ width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
+ sizeOfCall.getResult(0));
+ } else {
+ width = rewriter.create<emitc::ConstantOp>(
+ op.getLoc(), rhsType,
+ rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
+ }
+
+ Value excessCheck = rewriter.create<emitc::CmpOp>(
+ op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);
+
+ // Out-of-range shifts yield poison; the concrete value 0 is a valid refinement.
+ Value poison = rewriter.create<emitc::ConstantOp>(
+ op.getLoc(), arithmeticType,
+ (isa<IntegerType>(arithmeticType)
+ ? rewriter.getIntegerAttr(arithmeticType, 0)
+ : rewriter.getIndexAttr(0)));
+
+ emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
+ op.getLoc(), arithmeticType, /*do_not_inline=*/false);
+ Block &bodyBlock = ternary.getBodyRegion().emplaceBlock(); // keeps the shift inside the C ?: operand
+ OpBuilder::InsertPoint outerPoint = rewriter.saveInsertionPoint();
+ rewriter.setInsertionPointToStart(&bodyBlock);
+ Value arithmeticResult =
+ rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs);
+ Value resultOrPoison = rewriter.create<emitc::ConditionalOp>(
+ op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison);
+ rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison);
+ rewriter.restoreInsertionPoint(outerPoint); // back to where the pattern started, whatever block that is
+
+ Value result = adaptValueType(ternary, rewriter, type);
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+template <typename ArithOp, typename EmitCOp>
+class SignedShiftOpConversion final
+ : public ShiftOpConversion<ArithOp, EmitCOp, /*isUnsignedOp=*/false> {
+ using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
+};
+
+template <typename ArithOp, typename EmitCOp>
+class UnsignedShiftOpConversion final
+ : public ShiftOpConversion<ArithOp, EmitCOp, /*isUnsignedOp=*/true> {
+ using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
+};
+
class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
public:
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
@@ -647,6 +732,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
+ UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
+ SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
+ UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
CmpFOpConversion,
CmpIOpConversion,
NegFOpConversion,
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index caef04052aa8c..766ad4039335e 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -110,3 +110,27 @@ func.func @arith_extsi_i1_to_i32(%arg0: i1) {
%idx = arith.extsi %arg0 : i1 to i32
return
}
+
+// -----
+
+func.func @arith_shli_i1(%arg0: i1, %arg1: i1) { // the shift patterns bail out on i1, so nothing legalizes it
+ // expected-error @+1 {{failed to legalize operation 'arith.shli'}}
+ %shli = arith.shli %arg0, %arg1 : i1
+ return
+}
+
+// -----
+
+func.func @arith_shrsi_i1(%arg0: i1, %arg1: i1) { // the shift patterns bail out on i1, so nothing legalizes it
+ // expected-error @+1 {{failed to legalize operation 'arith.shrsi'}}
+ %shrsi = arith.shrsi %arg0, %arg1 : i1
+ return
+}
+
+// -----
+
+func.func @arith_shrui_i1(%arg0: i1, %arg1: i1) { // the shift patterns bail out on i1, so nothing legalizes it
+ // expected-error @+1 {{failed to legalize operation 'arith.shrui'}}
+ %shrui = arith.shrui %arg0, %arg1 : i1
+ return
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 89a57d1d7cebc..ac4bc609f0f42 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -144,6 +144,116 @@ func.func @arith_signed_integer_div_rem(%arg0: i32, %arg1: i32) {
// -----
+// CHECK-LABEL: arith_shift_left
+// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+func.func @arith_shift_left(%arg0: i32, %arg1: i32) {
+ // CHECK-DAG: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
+ // CHECK-DAG: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
+ // CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32
+ // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1
+ // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}ui32
+ // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32
+ // CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
+ // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : ui32
+ // CHECK: emitc.yield %[[Ternary]] : ui32
+ // CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32
+ %1 = arith.shli %arg0, %arg1 : i32
+ return
+}
+
+// -----
+
+// CHECK-LABEL: arith_shift_right
+// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+func.func @arith_shift_right(%arg0: i32, %arg1: i32) {
+ // CHECK-DAG: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
+ // CHECK-DAG: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
+ // CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32
+ // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1
+ // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}ui32
+ // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32
+ // CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
+ // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : ui32
+ // CHECK: emitc.yield %[[Ternary]] : ui32
+ // CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32
+ %2 = arith.shrui %arg0, %arg1 : i32
+
+ // CHECK-DAG: %[[SC2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
+ // CHECK-DAG: %[[SSizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32
+ // CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[SC2]], %[[SSizeConstant]] : (ui32, ui32) -> i1
+ // CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0 : i32
+ // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : i32
+ // CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[ARG0]], %[[SC2]] : (i32, ui32) -> i32
+ // CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : i32
+ // CHECK: emitc.yield %[[STernary]] : i32
+ %3 = arith.shrsi %arg0, %arg1 : i32
+
+ return
+}
+
+// -----
+
+// CHECK-LABEL: arith_shift_left_index
+// CHECK-SAME: %[[AMOUNT:.*]]: i32
+func.func @arith_shift_left_index(%amount: i32) {
+ %cst0 = "arith.constant"() {value = 42 : index} : () -> (index)
+ %cast1 = arith.index_cast %amount : i32 to index
+ // CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t
+ // CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ssize_t
+ // CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ssize_t to !emitc.size_t
+ // CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index
+ // CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t
+ // CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+ // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
+ // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.size_t
+ // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t
+ // CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+ // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : !emitc.size_t
+ // CHECK: emitc.yield %[[Ternary]] : !emitc.size_t
+ %1 = arith.shli %cst0, %cast1 : index
+ return
+}
+
+// -----
+
+// CHECK-LABEL: arith_shift_right_index
+// CHECK-SAME: %[[AMOUNT:.*]]: i32
+func.func @arith_shift_right_index(%amount: i32) {
+ // CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t
+ // CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ssize_t
+ // CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ssize_t to !emitc.size_t
+ %arg0 = "arith.constant"() {value = 42 : index} : () -> (index)
+ %arg1 = arith.index_cast %amount : i32 to index
+
+ // CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index
+ // CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t
+ // CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+ // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
+ // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.size_t
+ // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t
+ // CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+ // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : !emitc.size_t
+ // CHECK: emitc.yield %[[Ternary]] : !emitc.size_t
+ %2 = arith.shrui %arg0, %arg1 : index
+
+ // CHECK-DAG: %[[SC1:[^ ]*]] = emitc.cast %[[C1]] : !emitc.size_t to !emitc.ssize_t
+ // CHECK-DAG: %[[SByte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index{{.*}}!emitc.size_t
+ // CHECK-DAG: %[[SSizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[SByte]]) : (!emitc.size_t) -> !emitc.size_t
+ // CHECK-DAG: %[[SSizeConstant:[^ ]*]] = emitc.mul %[[SByte]], %[[SSizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+ // CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SSizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
+ // CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.ssize_t
+ // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : !emitc.ssize_t
+ // CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[SC1]], %[[AmountIdx]] : (!emitc.ssize_t, !emitc.size_t) -> !emitc.ssize_t
+ // CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : !emitc.ssize_t
+ // CHECK: emitc.yield %[[STernary]] : !emitc.ssize_t
+ // CHECK: emitc.cast %[[SShiftRes]] : !emitc.ssize_t to !emitc.size_t
+ %3 = arith.shrsi %arg0, %arg1 : index
+
+ return
+}
+
+// -----
+
func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () {
// CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32>
%0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>
>From 027b90710647579bb3835fd80f9596afe538baa0 Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Mon, 17 Jun 2024 15:25:11 +0100
Subject: [PATCH 3/4] Fix type mismatches, possibly null operands
---
.../Conversion/ArithToEmitC/ArithToEmitC.cpp | 18 +++++++--------
.../ArithToEmitC/arith-to-emitc.mlir | 22 +++++++++----------
2 files changed, 20 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index c16dcf65868ff..84cc29aaffe50 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -275,9 +275,9 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
ConversionPatternRewriter &rewriter) const override {
Type type = adaptor.getLhs().getType();
- if (type && !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
+ if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
return rewriter.notifyMatchFailure(
- op, "expected integer or size_t/ssize_t type");
+ op, "expected integer or size_t/ssize_t/ptrdiff_t type");
}
bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
@@ -330,10 +330,10 @@ class CastConversion : public OpConversionPattern<ArithOp> {
ConversionPatternRewriter &rewriter) const override {
Type opReturnType = this->getTypeConverter()->convertType(op.getType());
- if (opReturnType && !(isa_and_nonnull<IntegerType>(opReturnType) ||
+ if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
emitc::isPointerWideType(opReturnType)))
return rewriter.notifyMatchFailure(
- op, "expected integer or size_t/ssize_t result type");
+ op, "expected integer or size_t/ssize_t/ptrdiff_t result type");
if (adaptor.getOperands().size() != 1) {
return rewriter.notifyMatchFailure(
@@ -341,10 +341,10 @@ class CastConversion : public OpConversionPattern<ArithOp> {
}
Type operandType = adaptor.getIn().getType();
- if (operandType && !(isa_and_nonnull<IntegerType>(operandType) ||
+ if (!operandType || !(isa<IntegerType>(operandType) ||
emitc::isPointerWideType(operandType)))
return rewriter.notifyMatchFailure(
- op, "expected integer or size_t/ssize_t operand type");
+ op, "expected integer or size_t/ssize_t/ptrdiff_t operand type");
// Signed (sign-extending) casts from i1 are not supported.
if (operandType.isInteger(1) && !castToUnsigned)
@@ -433,7 +433,7 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
ConversionPatternRewriter &rewriter) const override {
Type type = this->getTypeConverter()->convertType(op.getType());
- if (type && !(isa_and_nonnull<IntegerType>(type) ||
+ if (!type || !(isa_and_nonnull<IntegerType>(type) ||
emitc::isPointerWideType(type))) {
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t/ptrdiff_t type");
@@ -517,10 +517,10 @@ class ShiftOpConversion : public OpConversionPattern<ArithOp> {
ConversionPatternRewriter &rewriter) const override {
Type type = this->getTypeConverter()->convertType(op.getType());
- if (type && !(isa_and_nonnull<IntegerType>(type) ||
+ if (!type || !(isa_and_nonnull<IntegerType>(type) ||
emitc::isPointerWideType(type))) {
return rewriter.notifyMatchFailure(
- op, "expected integer or size_t/ssize_t type");
+ op, "expected integer or size_t/ssize_t/ptrdiff_t type");
}
if (type.isInteger(1)) {
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index ac4bc609f0f42..858ccd1171445 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -199,8 +199,8 @@ func.func @arith_shift_left_index(%amount: i32) {
%cst0 = "arith.constant"() {value = 42 : index} : () -> (index)
%cast1 = arith.index_cast %amount : i32 to index
// CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t
- // CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ssize_t
- // CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ssize_t to !emitc.size_t
+ // CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ptrdiff_t
+ // CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ptrdiff_t to !emitc.size_t
// CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index
// CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t
// CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
@@ -220,8 +220,8 @@ func.func @arith_shift_left_index(%amount: i32) {
// CHECK-SAME: %[[AMOUNT:.*]]: i32
func.func @arith_shift_right_index(%amount: i32) {
// CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t
- // CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ssize_t
- // CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ssize_t to !emitc.size_t
+ // CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ptrdiff_t
+ // CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ptrdiff_t to !emitc.size_t
%arg0 = "arith.constant"() {value = 42 : index} : () -> (index)
%arg1 = arith.index_cast %amount : i32 to index
@@ -236,17 +236,17 @@ func.func @arith_shift_right_index(%amount: i32) {
// CHECK: emitc.yield %[[Ternary]] : !emitc.size_t
%2 = arith.shrui %arg0, %arg1 : index
- // CHECK-DAG: %[[SC1:[^ ]*]] = emitc.cast %[[C1]] : !emitc.size_t to !emitc.ssize_t
+ // CHECK-DAG: %[[SC1:[^ ]*]] = emitc.cast %[[C1]] : !emitc.size_t to !emitc.ptrdiff_t
// CHECK-DAG: %[[SByte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index{{.*}}!emitc.size_t
// CHECK-DAG: %[[SSizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[SByte]]) : (!emitc.size_t) -> !emitc.size_t
// CHECK-DAG: %[[SSizeConstant:[^ ]*]] = emitc.mul %[[SByte]], %[[SSizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
// CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SSizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
- // CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.ssize_t
- // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : !emitc.ssize_t
- // CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[SC1]], %[[AmountIdx]] : (!emitc.ssize_t, !emitc.size_t) -> !emitc.ssize_t
- // CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : !emitc.ssize_t
- // CHECK: emitc.yield %[[STernary]] : !emitc.ssize_t
- // CHECK: emitc.cast %[[SShiftRes]] : !emitc.ssize_t to !emitc.size_t
+ // CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.ptrdiff_t
+ // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : !emitc.ptrdiff_t
+ // CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[SC1]], %[[AmountIdx]] : (!emitc.ptrdiff_t, !emitc.size_t) -> !emitc.ptrdiff_t
+ // CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : !emitc.ptrdiff_t
+ // CHECK: emitc.yield %[[STernary]] : !emitc.ptrdiff_t
+ // CHECK: emitc.cast %[[SShiftRes]] : !emitc.ptrdiff_t to !emitc.size_t
%3 = arith.shrsi %arg0, %arg1 : index
return
>From 334a9163d6f15857f67d829b711fb26ea743181e Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Mon, 17 Jun 2024 15:34:12 +0100
Subject: [PATCH 4/4] Fix clang-format
---
mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 84cc29aaffe50..b0c9d083ddd88 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -331,7 +331,7 @@ class CastConversion : public OpConversionPattern<ArithOp> {
Type opReturnType = this->getTypeConverter()->convertType(op.getType());
if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
- emitc::isPointerWideType(opReturnType)))
+ emitc::isPointerWideType(opReturnType)))
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t/ptrdiff_t result type");
@@ -342,7 +342,7 @@ class CastConversion : public OpConversionPattern<ArithOp> {
Type operandType = adaptor.getIn().getType();
if (!operandType || !(isa<IntegerType>(operandType) ||
- emitc::isPointerWideType(operandType)))
+ emitc::isPointerWideType(operandType)))
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t/ptrdiff_t operand type");
@@ -434,7 +434,7 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
Type type = this->getTypeConverter()->convertType(op.getType());
if (!type || !(isa_and_nonnull<IntegerType>(type) ||
- emitc::isPointerWideType(type))) {
+ emitc::isPointerWideType(type))) {
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t/ptrdiff_t type");
}
@@ -518,7 +518,7 @@ class ShiftOpConversion : public OpConversionPattern<ArithOp> {
Type type = this->getTypeConverter()->convertType(op.getType());
if (!type || !(isa_and_nonnull<IntegerType>(type) ||
- emitc::isPointerWideType(type))) {
+ emitc::isPointerWideType(type))) {
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t/ptrdiff_t type");
}
More information about the Mlir-commits
mailing list