[Mlir-commits] [mlir] [MLIR] Testing arith-to-emitc conversions using opaque types (PR #137936)

Niklas Degener llvmlistbot at llvm.org
Wed Apr 30 07:40:53 PDT 2025


https://github.com/ndegener-amd updated https://github.com/llvm/llvm-project/pull/137936

>From 40f1409ee810b18e2daed6b3ca6e596d90d054fa Mon Sep 17 00:00:00 2001
From: Niklas Degener <niklas.degener at amd.com>
Date: Thu, 24 Apr 2025 09:22:35 +0200
Subject: [PATCH 1/4] Added test cases for arith-to-emitc conversions using
 opaque types

---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.h    |  8 ++
 .../Conversion/ArithToEmitC/ArithToEmitC.cpp  | 84 ++++++++++++-------
 .../ArithToEmitC/ArithToEmitCPass.cpp         | 35 +++++++-
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           |  8 ++
 .../arith-to-emitc-unsupported.mlir           | 16 ----
 .../ArithToEmitC/arith-to-emitc.mlir          | 71 ++++++++++++++++
 6 files changed, 175 insertions(+), 47 deletions(-)

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index 57029c64ffd00..6adbb475cdbf8 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -46,6 +46,14 @@ bool isIntegerIndexOrOpaqueType(Type type);
 /// Determines whether \p type is a valid floating-point type in EmitC.
 bool isSupportedFloatType(mlir::Type type);
 
+/// Determines whether \p type is any emitc.opaque type or a floating-point
+/// type accepted by isSupportedFloatType.
+bool isFloatOrOpaqueType(mlir::Type type);
+
+/// Determines whether \p type is any emitc.opaque type or an integer type
+/// accepted by isSupportedIntegerType.
+bool isIntegerOrOpaqueType(mlir::Type type);
+
 /// Determines whether \p type is a emitc.size_t/ssize_t type.
 bool isPointerWideType(mlir::Type type);
 
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 359d7b2279639..88f8651618df0 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -67,6 +67,7 @@ Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
 
 /// Insert a cast operation to type \p ty if \p val does not have this type.
 Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
+  assert(emitc::isSupportedEmitCType(val.getType()));
   return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val);
 }
 
@@ -273,7 +274,8 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     Type type = adaptor.getLhs().getType();
-    if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
+    if (!type || !(emitc::isIntegerOrOpaqueType(type) ||
+                   emitc::isPointerWideType(type))) {
       return rewriter.notifyMatchFailure(
           op, "expected integer or size_t/ssize_t/ptrdiff_t type");
     }
@@ -328,7 +330,7 @@ class CastConversion : public OpConversionPattern<ArithOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     Type opReturnType = this->getTypeConverter()->convertType(op.getType());
-    if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
+    if (!opReturnType || !(emitc::isIntegerOrOpaqueType(opReturnType) ||
                            emitc::isPointerWideType(opReturnType)))
       return rewriter.notifyMatchFailure(
           op, "expected integer or size_t/ssize_t/ptrdiff_t result type");
@@ -339,7 +341,7 @@ class CastConversion : public OpConversionPattern<ArithOp> {
     }
 
     Type operandType = adaptor.getIn().getType();
-    if (!operandType || !(isa<IntegerType>(operandType) ||
+    if (!operandType || !(emitc::isIntegerOrOpaqueType(operandType) ||
                           emitc::isPointerWideType(operandType)))
       return rewriter.notifyMatchFailure(
           op, "expected integer or size_t/ssize_t/ptrdiff_t operand type");
@@ -433,7 +435,8 @@ class BinaryUIOpConversion final : public OpConversionPattern<ArithOp> {
     if (!newRetTy)
       return rewriter.notifyMatchFailure(uiBinOp,
                                          "converting result type failed");
-    if (!isa<IntegerType>(newRetTy)) {
+
+    if (!emitc::isIntegerOrOpaqueType(newRetTy)) {
       return rewriter.notifyMatchFailure(uiBinOp, "expected integer type");
     }
     Type unsignedType =
@@ -441,8 +444,8 @@ class BinaryUIOpConversion final : public OpConversionPattern<ArithOp> {
     if (!unsignedType)
       return rewriter.notifyMatchFailure(uiBinOp,
                                          "converting result type failed");
-    Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType);
-    Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType);
+    Value lhsAdapted = adaptValueType(adaptor.getLhs(), rewriter, unsignedType);
+    Value rhsAdapted = adaptValueType(adaptor.getRhs(), rewriter, unsignedType);
 
     auto newDivOp =
         rewriter.create<EmitCOp>(uiBinOp.getLoc(), unsignedType,
@@ -463,7 +466,8 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     Type type = this->getTypeConverter()->convertType(op.getType());
-    if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
+    if (!type || !(emitc::isIntegerOrOpaqueType(type) ||
+                   emitc::isPointerWideType(type))) {
       return rewriter.notifyMatchFailure(
           op, "expected integer or size_t/ssize_t/ptrdiff_t type");
     }
@@ -506,7 +510,7 @@ class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     Type type = this->getTypeConverter()->convertType(op.getType());
-    if (!isa_and_nonnull<IntegerType>(type)) {
+    if (!type || !emitc::isIntegerOrOpaqueType(type)) {
       return rewriter.notifyMatchFailure(
           op,
           "expected integer type, vector/tensor support not yet implemented");
@@ -546,7 +550,9 @@ class ShiftOpConversion : public OpConversionPattern<ArithOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     Type type = this->getTypeConverter()->convertType(op.getType());
-    if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
+    bool retIsOpaque = isa_and_nonnull<emitc::OpaqueType>(type);
+    if (!type || (!retIsOpaque && !(isa<IntegerType>(type) ||
+                                    emitc::isPointerWideType(type)))) {
       return rewriter.notifyMatchFailure(
           op, "expected integer or size_t/ssize_t/ptrdiff_t type");
     }
@@ -572,21 +578,33 @@ class ShiftOpConversion : public OpConversionPattern<ArithOp> {
           op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
       width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
                                             sizeOfCall.getResult(0));
-    } else {
+    } else if (!retIsOpaque) {
       width = rewriter.create<emitc::ConstantOp>(
           op.getLoc(), rhsType,
           rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
+    } else {
+      width = rewriter.create<emitc::ConstantOp>(
+          op.getLoc(), rhsType,
+          emitc::OpaqueAttr::get(rhsType.getContext(),
+                                 "opaque_shift_bitwidth"));
     }
 
     Value excessCheck = rewriter.create<emitc::CmpOp>(
         op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);
 
     // Any concrete value is a valid refinement of poison.
-    Value poison = rewriter.create<emitc::ConstantOp>(
-        op.getLoc(), arithmeticType,
-        (isa<IntegerType>(arithmeticType)
-             ? rewriter.getIntegerAttr(arithmeticType, 0)
-             : rewriter.getIndexAttr(0)));
+    Value poison;
+    if (retIsOpaque) {
+      poison = rewriter.create<emitc::ConstantOp>(
+          op.getLoc(), arithmeticType,
+          emitc::OpaqueAttr::get(rhsType.getContext(), "opaque_shift_poison"));
+    } else {
+      poison = rewriter.create<emitc::ConstantOp>(
+          op.getLoc(), arithmeticType,
+          (isa<IntegerType>(arithmeticType)
+               ? rewriter.getIntegerAttr(arithmeticType, 0)
+               : rewriter.getIndexAttr(0)));
+    }
 
     emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
         op.getLoc(), arithmeticType, /*do_not_inline=*/false);
@@ -663,19 +681,23 @@ class FtoICastOpConversion : public OpConversionPattern<CastOp> {
     if (!dstType)
       return rewriter.notifyMatchFailure(castOp, "type conversion failed");
 
+    Type actualResultType = dstType;
+
     // Float-to-i1 casts are not supported: any value with 0 < value < 1 must be
     // truncated to 0, whereas a boolean conversion would return true.
-    if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1))
-      return rewriter.notifyMatchFailure(castOp,
-                                         "unsupported cast destination type");
-
-    // Convert to unsigned if it's the "ui" variant
-    // Signless is interpreted as signed, so no need to cast for "si"
-    Type actualResultType = dstType;
-    if (isa<arith::FPToUIOp>(castOp)) {
-      actualResultType =
-          rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(),
-                                  /*isSigned=*/false);
+    bool dstIsOpaque = isa<emitc::OpaqueType>(dstType);
+    if (!dstIsOpaque) {
+      if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1))
+        return rewriter.notifyMatchFailure(castOp,
+                                           "unsupported cast destination type");
+
+      // Convert to unsigned if it's the "ui" variant
+      // Signless is interpreted as signed, so no need to cast for "si"
+      if (isa<arith::FPToUIOp>(castOp)) {
+        actualResultType =
+            rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(),
+                                    /*isSigned=*/false);
+      }
     }
 
     Value result = rewriter.create<emitc::CastOp>(
@@ -702,7 +724,9 @@ class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
                   ConversionPatternRewriter &rewriter) const override {
     // Vectors in particular are not supported
     Type operandType = adaptor.getIn().getType();
-    if (!emitc::isSupportedIntegerType(operandType))
+    bool opIsOpaque = isa<emitc::OpaqueType>(operandType);
+
+    if (!(opIsOpaque || emitc::isSupportedIntegerType(operandType)))
       return rewriter.notifyMatchFailure(castOp,
                                          "unsupported cast source type");
 
@@ -717,7 +741,7 @@ class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
     // Convert to unsigned if it's the "ui" variant
     // Signless is interpreted as signed, so no need to cast for "si"
     Type actualOperandType = operandType;
-    if (isa<arith::UIToFPOp>(castOp)) {
+    if (!opIsOpaque && isa<arith::UIToFPOp>(castOp)) {
       actualOperandType =
           rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
                                   /*isSigned=*/false);
@@ -745,7 +769,7 @@ class FpCastOpConversion : public OpConversionPattern<CastOp> {
                   ConversionPatternRewriter &rewriter) const override {
     // Vectors in particular are not supported.
     Type operandType = adaptor.getIn().getType();
-    if (!emitc::isSupportedFloatType(operandType))
+    if (!emitc::isFloatOrOpaqueType(operandType))
       return rewriter.notifyMatchFailure(castOp,
                                          "unsupported cast source type");
     if (auto roundingModeOp =
@@ -759,7 +783,7 @@ class FpCastOpConversion : public OpConversionPattern<CastOp> {
     if (!dstType)
       return rewriter.notifyMatchFailure(castOp, "type conversion failed");
 
-    if (!emitc::isSupportedFloatType(dstType))
+    if (!emitc::isFloatOrOpaqueType(dstType))
       return rewriter.notifyMatchFailure(castOp,
                                          "unsupported cast destination type");
 
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
index 45a088ed144f1..30d5eed9dceaa 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
@@ -30,9 +30,42 @@ namespace {
 struct ConvertArithToEmitC
     : public impl::ConvertArithToEmitCBase<ConvertArithToEmitC> {
   void runOnOperation() override;
+
+  /// Adds conversions mapping f80/i80 types and attributes (unsupported in
+  /// EmitC) to emitc.opaque, for testing; attribute values are not preserved.
+  void populateOpaqueTypeConversions(TypeConverter &converter);
 };
 } // namespace
 
+void ConvertArithToEmitC::populateOpaqueTypeConversions(
+    TypeConverter &converter) {
+  converter.addConversion([](Type type) -> std::optional<Type> {
+    if (type.isF80())
+      return emitc::OpaqueType::get(type.getContext(), "f80");
+    if (type.isInteger() && type.getIntOrFloatBitWidth() == 80)
+      return emitc::OpaqueType::get(type.getContext(), "i80");
+    return type;
+  });
+
+  converter.addTypeAttributeConversion(
+      [](Type type,
+         Attribute attrToConvert) -> TypeConverter::AttributeConversionResult {
+        if (auto floatAttr = llvm::dyn_cast<FloatAttr>(attrToConvert)) {
+          if (floatAttr.getType().isF80()) {
+            return emitc::OpaqueAttr::get(type.getContext(), "f80");
+          }
+          return {};
+        }
+        if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attrToConvert)) {
+          if (intAttr.getType().isInteger() &&
+              intAttr.getType().getIntOrFloatBitWidth() == 80) {
+            return emitc::OpaqueAttr::get(type.getContext(), "i80");
+          }
+        }
+        return {};
+      });
+}
+
 void ConvertArithToEmitC::runOnOperation() {
   ConversionTarget target(getContext());
 
@@ -42,8 +75,8 @@ void ConvertArithToEmitC::runOnOperation() {
   RewritePatternSet patterns(&getContext());
 
   TypeConverter typeConverter;
-  typeConverter.addConversion([](Type type) { return type; });
 
+  populateOpaqueTypeConversions(typeConverter);
   populateArithToEmitCPatterns(typeConverter, patterns);
 
   if (failed(
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index b4d7482554fbc..c34752f92c794 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -132,6 +132,14 @@ bool mlir::emitc::isSupportedFloatType(Type type) {
   return false;
 }
 
+bool mlir::emitc::isIntegerOrOpaqueType(Type type) {
+  return isa<emitc::OpaqueType>(type) || isSupportedIntegerType(type);
+}
+
+bool mlir::emitc::isFloatOrOpaqueType(Type type) {
+  return isa<emitc::OpaqueType>(type) || isSupportedFloatType(type);
+}
+
 bool mlir::emitc::isPointerWideType(Type type) {
   return isa<emitc::SignedSizeTType, emitc::SizeTType, emitc::PtrDiffTType>(
       type);
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index 9850f336b5ad6..e652ed38a21d2 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -14,13 +14,6 @@ func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> {
   return %t: vector<5xi32>
 }
 
-// -----
-func.func @arith_cast_f80(%arg0: f80) -> i32 {
-  // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
-  %t = arith.fptosi %arg0 : f80 to i32
-  return %t: i32
-}
-
 // -----
 
 func.func @arith_cast_f128(%arg0: f128) -> i32 {
@@ -29,15 +22,6 @@ func.func @arith_cast_f128(%arg0: f128) -> i32 {
   return %t: i32
 }
 
-
-// -----
-
-func.func @arith_cast_to_f80(%arg0: i32) -> f80 {
-  // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
-  %t = arith.sitofp %arg0 : i32 to f80
-  return %t: f80
-}
-
 // -----
 
 func.func @arith_cast_to_f128(%arg0: i32) -> f128 {
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index cb1d092918f03..3db27db653a80 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -771,3 +771,74 @@ func.func @arith_truncf(%arg0: f64) -> f16 {
 
   return %truncd1 : f16
 }
+
+// -----
+
+func.func @float_opaque_conversion(%arg0: f80, %arg1: f80) {
+  // CHECK-LABEL: float_opaque_conversion
+  // CHECK-SAME: (%[[Arg0:[^ ]*]]: f80, %[[Arg1:[^ ]*]]: f80)
+
+  // CHECK-DAG: [[arg1_cast:[^ ]*]] = builtin.unrealized_conversion_cast %[[Arg1]] : f80 to !emitc.opaque<"f80"> 
+  // CHECK-DAG: [[arg0_cast:[^ ]*]] = builtin.unrealized_conversion_cast %[[Arg0]] : f80 to !emitc.opaque<"f80"> 
+  // CHECK: "emitc.constant"() <{value = #emitc.opaque<"f80">}> : () -> !emitc.opaque<"f80">
+  %10 = arith.constant 0.0 : f80
+  // CHECK: emitc.add [[arg0_cast]], [[arg1_cast]] : (!emitc.opaque<"f80">, !emitc.opaque<"f80">) -> !emitc.opaque<"f80">
+  %2 = arith.addf %arg0, %arg1 : f80
+  // CHECK: [[EQ:[^ ]*]] = emitc.cmp eq, [[arg0_cast]], [[arg1_cast]] : (!emitc.opaque<"f80">, !emitc.opaque<"f80">) -> i1
+  // CHECK: [[NotNaNArg0:[^ ]*]] = emitc.cmp eq, [[arg0_cast]], [[arg0_cast]] : (!emitc.opaque<"f80">, !emitc.opaque<"f80">) -> i1
+  // CHECK: [[NotNaNArg1:[^ ]*]] = emitc.cmp eq, [[arg1_cast]], [[arg1_cast]] : (!emitc.opaque<"f80">, !emitc.opaque<"f80">) -> i1
+  // CHECK: [[Ordered:[^ ]*]] = emitc.logical_and [[NotNaNArg0]], [[NotNaNArg1]] : i1, i1
+  // CHECK: emitc.logical_and [[Ordered]], [[EQ]] : i1, i1
+  %11 = arith.cmpf oeq, %arg0, %arg1 : f80
+  // CHECK: emitc.unary_minus [[arg0_cast]] : (!emitc.opaque<"f80">) -> !emitc.opaque<"f80">
+  %12 = arith.negf %arg0 : f80
+  // CHECK: [[V0:[^ ]*]] = emitc.cast [[arg0_cast]] : !emitc.opaque<"f80"> to ui32
+  // CHECK: [[V1:[^ ]*]] = emitc.cast [[V0]] : ui32 to i32
+  %7 = arith.fptoui %arg0 : f80 to i32
+  // CHECK: emitc.cast [[V1]] : i32 to !emitc.opaque<"f80">
+  %8 = arith.sitofp %7 : i32 to f80
+  // CHECK: [[trunc:[^ ]*]] = emitc.cast [[arg0_cast]] : !emitc.opaque<"f80"> to f32
+  %13 = arith.truncf %arg0 : f80 to f32
+  // CHECK: emitc.cast [[trunc]] : f32 to !emitc.opaque<"f80">
+  %15 = arith.extf %13 : f32 to f80
+  return
+}
+
+// -----
+
+func.func @int_opaque_conversion(%arg0: i80, %arg1: i80, %arg2: i1) {
+  // CHECK-LABEL: int_opaque_conversion
+  // CHECK-SAME: (%[[Arg0:[^ ]*]]: i80, %[[Arg1:[^ ]*]]: i80, %[[Arg2:[^ ]*]]: i1)
+
+  // CHECK-DAG: [[arg1_cast:[^ ]*]] = builtin.unrealized_conversion_cast %[[Arg1]] : i80 to !emitc.opaque<"i80"> 
+  // CHECK-DAG: [[arg0_cast:[^ ]*]] = builtin.unrealized_conversion_cast %[[Arg0]] : i80 to !emitc.opaque<"i80">
+  // CHECK: "emitc.constant"() <{value = #emitc.opaque<"i80">}> : () -> !emitc.opaque<"i80">
+  %10 = arith.constant 0 : i80
+  // CHECK: emitc.div [[arg0_cast]], [[arg1_cast]] : (!emitc.opaque<"i80">, !emitc.opaque<"i80">) -> !emitc.opaque<"i80">
+  %3 = arith.divui %arg0, %arg1 : i80
+  // CHECK: emitc.add [[arg0_cast]], [[arg1_cast]] : (!emitc.opaque<"i80">, !emitc.opaque<"i80">) -> !emitc.opaque<"i80">
+  %2 = arith.addi %arg0, %arg1 : i80
+  // CHECK: emitc.bitwise_and [[arg0_cast]], [[arg1_cast]] : (!emitc.opaque<"i80">, !emitc.opaque<"i80">) -> !emitc.opaque<"i80">
+  %14 = arith.andi %arg0, %arg1 : i80
+  // CHECK: [[Bitwidth:[^ ]*]] = "emitc.constant"() <{value = #emitc.opaque<"opaque_shift_bitwidth">}> : () -> !emitc.opaque<"i80">
+  // CHECK: [[LT:[^ ]*]] = emitc.cmp lt, [[arg1_cast]], [[Bitwidth]] : (!emitc.opaque<"i80">, !emitc.opaque<"i80">) -> i1
+  // CHECK: [[Poison:[^ ]*]] = "emitc.constant"() <{value = #emitc.opaque<"opaque_shift_poison">}> : () -> !emitc.opaque<"i80">
+  // CHECK: [[Exp:[^ ]*]] = emitc.expression : !emitc.opaque<"i80"> {
+  // CHECK: [[LShift:[^ ]*]] = emitc.bitwise_left_shift [[arg0_cast]], [[arg1_cast]] : (!emitc.opaque<"i80">, !emitc.opaque<"i80">) -> !emitc.opaque<"i80">
+  // CHECK: emitc.conditional [[LT]], [[LShift]], [[Poison]] : !emitc.opaque<"i80">
+  // CHECK: emitc.yield {{.*}} : !emitc.opaque<"i80">
+  // CHECK: }
+  %12 = arith.shli %arg0, %arg1 : i80
+  // CHECK: emitc.cmp eq, [[arg0_cast]], [[arg1_cast]] : (!emitc.opaque<"i80">, !emitc.opaque<"i80">) -> i1
+  %11 = arith.cmpi eq, %arg0, %arg1 : i80
+  // CHECK: emitc.conditional %[[Arg2]], [[arg0_cast]], [[arg1_cast]] : !emitc.opaque<"i80">
+  %13 = arith.select %arg2, %arg0, %arg1 : i80
+  // CHECK: [[V0:[^ ]*]] = emitc.cast [[arg0_cast]] : !emitc.opaque<"i80"> to ui8
+  // CHECK: emitc.cast [[V0]] : ui8 to i8
+  %15 = arith.trunci %arg0 : i80 to i8
+  // CHECK: [[V1:[^ ]*]] = emitc.cast [[arg0_cast]] : !emitc.opaque<"i80"> to f32
+  %9 = arith.uitofp %arg0 : i80 to f32
+  // CHECK: emitc.cast [[V1]] : f32 to !emitc.opaque<"i80">
+  %6 = arith.fptosi %9 : f32 to i80
+  return
+}

>From 94ae63f8d555c693d3b921e5e8a948dd87b29768 Mon Sep 17 00:00:00 2001
From: ndegener-amd <niklas.degener at amd.com>
Date: Wed, 30 Apr 2025 08:34:31 -0600
Subject: [PATCH 2/4] Removed deprecated emitc prefixes from emitc.expression

---
 mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 3db27db653a80..c1247cb97e440 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -824,9 +824,9 @@ func.func @int_opaque_conversion(%arg0: i80, %arg1: i80, %arg2: i1) {
   // CHECK: [[LT:[^ ]*]] = emitc.cmp lt, [[arg1_cast]], [[Bitwidth]] : (!emitc.opaque<"i80">, !emitc.opaque<"i80">) -> i1
   // CHECK: [[Poison:[^ ]*]] = "emitc.constant"() <{value = #emitc.opaque<"opaque_shift_poison">}> : () -> !emitc.opaque<"i80">
   // CHECK: [[Exp:[^ ]*]] = emitc.expression : !emitc.opaque<"i80"> {
-  // CHECK: [[LShift:[^ ]*]] = emitc.bitwise_left_shift [[arg0_cast]], [[arg1_cast]] : (!emitc.opaque<"i80">, !emitc.opaque<"i80">) -> !emitc.opaque<"i80">
-  // CHECK: emitc.conditional [[LT]], [[LShift]], [[Poison]] : !emitc.opaque<"i80">
-  // CHECK: emitc.yield {{.*}} : !emitc.opaque<"i80">
+  // CHECK: [[LShift:[^ ]*]] = bitwise_left_shift [[arg0_cast]], [[arg1_cast]] : (!emitc.opaque<"i80">, !emitc.opaque<"i80">) -> !emitc.opaque<"i80">
+  // CHECK: conditional [[LT]], [[LShift]], [[Poison]] : !emitc.opaque<"i80">
+  // CHECK: yield {{.*}} : !emitc.opaque<"i80">
   // CHECK: }
   %12 = arith.shli %arg0, %arg1 : i80
   // CHECK: emitc.cmp eq, [[arg0_cast]], [[arg1_cast]] : (!emitc.opaque<"i80">, !emitc.opaque<"i80">) -> i1

>From da471d1809c94e718a241080cf7675462be9037b Mon Sep 17 00:00:00 2001
From: ndegener-amd <niklas.degener at amd.com>
Date: Wed, 30 Apr 2025 08:35:15 -0600
Subject: [PATCH 3/4] Removed new "unsupported" test-case for f80

---
 .../ArithToEmitC/arith-to-emitc-unsupported.mlir          | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index e652ed38a21d2..a71de136e6d76 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -64,14 +64,6 @@ func.func @arith_cmpf_tensor(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> tens
 
 // -----
 
-func.func @arith_negf_f80(%arg0: f80) -> f80 {
-  // expected-error @+1 {{failed to legalize operation 'arith.negf'}}
-  %n = arith.negf %arg0 : f80
-  return %n: f80
-}
-
-// -----
-
 func.func @arith_negf_tensor(%arg0: tensor<5xf32>) -> tensor<5xf32> {
   // expected-error @+1 {{failed to legalize operation 'arith.negf'}}
   %n = arith.negf %arg0 : tensor<5xf32>

>From 665eadae42c0f9bf0a17e7df79ba927921b58437 Mon Sep 17 00:00:00 2001
From: ndegener-amd <niklas.degener at amd.com>
Date: Wed, 30 Apr 2025 08:40:14 -0600
Subject: [PATCH 4/4] Fixed attribute conversion for constantOps, legalized
 CmpF, NegF, FtoI, ItoF conversions for opaque types

---
 .../Conversion/ArithToEmitC/ArithToEmitC.cpp  | 20 ++++++++++++++-----
 .../ArithToEmitC/ArithToEmitCPass.cpp         |  4 ++--
 2 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 88f8651618df0..17bfe11a4429e 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -39,8 +40,17 @@ class ArithConstantOpConversionPattern
     Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
     if (!newTy)
       return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
+
+    std::optional<Attribute> opAttrib =
+        this->getTypeConverter()->convertTypeAttribute(
+            adaptor.getValue().getType(), adaptor.getValue());
+    if (!opAttrib) {
+      return rewriter.notifyMatchFailure(arithConst,
+                                         "attribute conversion failed");
+    }
+
     rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy,
-                                                   adaptor.getValue());
+                                                   opAttrib.value());
     return success();
   }
 };
@@ -79,7 +89,7 @@ class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    if (!isa<FloatType>(adaptor.getRhs().getType())) {
+    if (!emitc::isFloatOrOpaqueType(adaptor.getRhs().getType())) {
       return rewriter.notifyMatchFailure(op.getLoc(),
                                          "cmpf currently only supported on "
                                          "floats, not tensors/vectors thereof");
@@ -309,7 +319,7 @@ class NegFOpConversion : public OpConversionPattern<arith::NegFOp> {
           "negf currently only supports scalar types, not vectors or tensors");
     }
 
-    if (!emitc::isSupportedFloatType(adaptedOpType)) {
+    if (!emitc::isFloatOrOpaqueType(adaptedOpType)) {
       return rewriter.notifyMatchFailure(
           op.getLoc(), "floating-point type is not supported by EmitC");
     }
@@ -673,7 +683,7 @@ class FtoICastOpConversion : public OpConversionPattern<CastOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     Type operandType = adaptor.getIn().getType();
-    if (!emitc::isSupportedFloatType(operandType))
+    if (!emitc::isFloatOrOpaqueType(operandType))
       return rewriter.notifyMatchFailure(castOp,
                                          "unsupported cast source type");
 
@@ -734,7 +744,7 @@ class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
     if (!dstType)
       return rewriter.notifyMatchFailure(castOp, "type conversion failed");
 
-    if (!emitc::isSupportedFloatType(dstType))
+    if (!emitc::isFloatOrOpaqueType(dstType))
       return rewriter.notifyMatchFailure(castOp,
                                          "unsupported cast destination type");
 
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
index 30d5eed9dceaa..e64e4176d3baa 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
@@ -54,7 +54,7 @@ void ConvertArithToEmitC::populateOpaqueTypeConversions(
           if (floatAttr.getType().isF80()) {
             return emitc::OpaqueAttr::get(type.getContext(), "f80");
           }
-          return {};
+          return attrToConvert;
         }
         if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attrToConvert)) {
           if (intAttr.getType().isInteger() &&
@@ -62,7 +62,7 @@ void ConvertArithToEmitC::populateOpaqueTypeConversions(
             return emitc::OpaqueAttr::get(type.getContext(), "i80");
           }
         }
-        return {};
+        return attrToConvert;
       });
 }
 



More information about the Mlir-commits mailing list