[Mlir-commits] [mlir] [emitc][mlir] Add an option to cast array type to ptr type (PR #126385)

Andrey Timonin llvmlistbot at llvm.org
Sat Feb 8 11:22:19 PST 2025


https://github.com/EtoAndruwa created https://github.com/llvm/llvm-project/pull/126385

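This patch lets `emitc.cast` convert an `!emitc.array` operand into an `!emitc.ptr` result (C array-to-pointer decay), and drops the `SameOperandsAndResultShape` trait from `emitc.cast`, since an array operand and a pointer result do not share a shape.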

From e92a135823f9d7237e0451328c06e0ff2e306307 Mon Sep 17 00:00:00 2001
From: Andrey Timonin <timonina1909 at gmail.com>
Date: Sat, 8 Feb 2025 21:40:30 +0300
Subject: [PATCH] [emitc][mlir] Add an option to cast array type to ptr type

---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td |  3 +--
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp         | 11 ++++++-----
 mlir/test/Dialect/EmitC/invalid_ops.mlir    | 14 +++++++++++---
 mlir/test/Dialect/EmitC/ops.mlir            |  5 +++++
 4 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 4fbce995ce5b80..360f2e84343636 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -266,8 +266,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {
 
 def EmitC_CastOp : EmitC_Op<"cast",
     [CExpression,
-     DeclareOpInterfaceMethods<CastOpInterface>,
-     SameOperandsAndResultShape]> {
+     DeclareOpInterfaceMethods<CastOpInterface>]> {
   let summary = "Cast operation";
   let description = [{
     The `emitc.cast` operation performs an explicit type conversion and is emitted
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 728a2d33f46e7f..01effa5734caa6 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -247,11 +247,23 @@ LogicalResult emitc::AssignOp::verify() {
 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   Type input = inputs.front(), output = outputs.front();
 
+  // An array operand may only be cast to a pointer to its element type, which
+  // models C array-to-pointer decay. Only rank-1 arrays qualify: a
+  // multi-dimensional C array decays to a pointer to its first sub-array, not
+  // to a pointer to a scalar element. No type may be cast to an array.
+  if (auto arrayType = dyn_cast<emitc::ArrayType>(input)) {
+    auto pointerType = dyn_cast<emitc::PointerType>(output);
+    return pointerType && arrayType.getShape().size() == 1 &&
+           arrayType.getElementType() == pointerType.getPointee();
+  }
+
+  // Otherwise both sides must be integer/index/opaque, floating-point or
+  // pointer types.
   return (
       (emitc::isIntegerIndexOrOpaqueType(input) ||
        emitc::isSupportedFloatType(input) || isa<emitc::PointerType>(input)) &&
       (emitc::isIntegerIndexOrOpaqueType(output) ||
        emitc::isSupportedFloatType(output) || isa<emitc::PointerType>(output)));
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index a0d8d7f59de115..c40195dd3473aa 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -130,9 +130,41 @@ func.func @cast_tensor(%arg : tensor<f32>) {
 
 // -----
 
-func.func @cast_array(%arg : !emitc.array<4xf32>) {
-    // expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<4xf32>' and result type '!emitc.array<4xf32>' are cast incompatible}}
-    %1 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.array<4xf32>
+func.func @cast_to_array(%arg : f32) {
+    // expected-error @+1 {{'emitc.cast' op operand type 'f32' and result type '!emitc.array<4xf32>' are cast incompatible}}
+    %1 = emitc.cast %arg: f32 to !emitc.array<4xf32>
+    return
+}
+
+// -----
+
+func.func @cast_array_to_float(%arg : !emitc.array<4xf32>) {
+    // expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<4xf32>' and result type 'f32' are cast incompatible}}
+    %1 = emitc.cast %arg: !emitc.array<4xf32> to f32
+    return
+}
+
+// -----
+
+func.func @cast_array_to_pointer_element_mismatch(%arg : !emitc.array<3xi32>) {
+    // expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<3xi32>' and result type '!emitc.ptr<f32>' are cast incompatible}}
+    %1 = emitc.cast %arg: !emitc.array<3xi32> to !emitc.ptr<f32>
+    return
+}
+
+// -----
+
+func.func @cast_multidim_array_to_pointer(%arg : !emitc.array<2x3xi32>) {
+    // expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<2x3xi32>' and result type '!emitc.ptr<i32>' are cast incompatible}}
+    %1 = emitc.cast %arg: !emitc.array<2x3xi32> to !emitc.ptr<i32>
+    return
+}
+
+// -----
+
+func.func @cast_pointer_to_array(%arg : !emitc.ptr<i32>) {
+    // expected-error @+1 {{'emitc.cast' op operand type '!emitc.ptr<i32>' and result type '!emitc.array<3xi32>' are cast incompatible}}
+    %1 = emitc.cast %arg: !emitc.ptr<i32> to !emitc.array<3xi32>
     return
 }
 
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 7fd0a2d020397b..c6f90f56008555 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -39,6 +39,12 @@ func.func @cast(%arg0: i32) {
   return
 }
 
+// A rank-1 array may be cast to a pointer to its element type.
+func.func @cast_array_to_pointer(%arg0: !emitc.array<3xi32>) {
+  %1 = emitc.cast %arg0: !emitc.array<3xi32> to !emitc.ptr<i32>
+  return
+}
+
 func.func @c() {
   %1 = "emitc.constant"(){value = 42 : i32} : () -> i32
   %2 = "emitc.constant"(){value = 42 : index} : () -> !emitc.size_t



More information about the Mlir-commits mailing list