[Mlir-commits] [mlir] 1e1bf79 - [mlir][emitc] Add an option to cast array type to ptr type (#126385)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 20 08:40:10 PST 2025


Author: Andrey Timonin
Date: 2025-02-20T17:40:06+01:00
New Revision: 1e1bf7971b1b8c74aa4de2c055c402d0085e87b8

URL: https://github.com/llvm/llvm-project/commit/1e1bf7971b1b8c74aa4de2c055c402d0085e87b8
DIFF: https://github.com/llvm/llvm-project/commit/1e1bf7971b1b8c74aa4de2c055c402d0085e87b8.diff

LOG: [mlir][emitc] Add an option to cast array type to ptr type (#126385)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
    mlir/lib/Dialect/EmitC/IR/EmitC.cpp
    mlir/test/Dialect/EmitC/invalid_ops.mlir
    mlir/test/Dialect/EmitC/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 15f3a5a4742c0..fadee23b25175 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -313,8 +313,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {
 
 def EmitC_CastOp : EmitC_Op<"cast",
     [CExpression,
-     DeclareOpInterfaceMethods<CastOpInterface>,
-     SameOperandsAndResultShape]> {
+     DeclareOpInterfaceMethods<CastOpInterface>]> {
   let summary = "Cast operation";
   let description = [{
     The `emitc.cast` operation performs an explicit type conversion and is emitted

diff  --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index eb7ffe2e032c4..b4d7482554fbc 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -305,6 +305,14 @@ LogicalResult emitc::AssignOp::verify() {
 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   Type input = inputs.front(), output = outputs.front();
 
+  if (auto arrayType = dyn_cast<emitc::ArrayType>(input)) {
+    if (auto pointerType = dyn_cast<emitc::PointerType>(output)) {
+      return (arrayType.getElementType() == pointerType.getPointee()) &&
+             arrayType.getShape().size() == 1 && arrayType.getShape()[0] >= 1;
+    }
+    return false;
+  }
+
   return (
       (emitc::isIntegerIndexOrOpaqueType(input) ||
        emitc::isSupportedFloatType(input) || isa<emitc::PointerType>(input)) &&
@@ -757,9 +765,9 @@ void IfOp::print(OpAsmPrinter &p) {
 
 /// Given the region at `index`, or the parent operation if `index` is None,
 /// return the successor regions. These are the regions that may be selected
-/// during the flow of control. `operands` is a set of optional attributes that
-/// correspond to a constant value for each operand, or null if that operand is
-/// not a constant.
+/// during the flow of control. `operands` is a set of optional attributes
+/// that correspond to a constant value for each operand, or null if that
+/// operand is not a constant.
 void IfOp::getSuccessorRegions(RegionBranchPoint point,
                                SmallVectorImpl<RegionSuccessor> &regions) {
   // The `then` and the `else` region branch back to the parent operation.
@@ -1086,8 +1094,8 @@ emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
 LogicalResult mlir::emitc::LValueType::verify(
     llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
     mlir::Type value) {
-  // Check that the wrapped type is valid. This especially forbids nested lvalue
-  // types.
+  // Check that the wrapped type is valid. This especially forbids nested
+  // lvalue types.
   if (!isSupportedEmitCType(value))
     return emitError()
            << "!emitc.lvalue must wrap supported emitc type, but got " << value;

diff  --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 13bd96f6d9fb4..5b4e3f92f6d53 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -130,9 +130,41 @@ func.func @cast_tensor(%arg : tensor<f32>) {
 
 // -----
 
-func.func @cast_array(%arg : !emitc.array<4xf32>) {
-    // expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<4xf32>' and result type '!emitc.array<4xf32>' are cast incompatible}}
-    %1 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.array<4xf32>
+func.func @cast_to_array(%arg : f32) {
+    // expected-error @+1 {{'emitc.cast' op operand type 'f32' and result type '!emitc.array<4xf32>' are cast incompatible}}
+    %1 = emitc.cast %arg: f32 to !emitc.array<4xf32>
+    return
+}
+
+// -----
+
+func.func @cast_multidimensional_array(%arg : !emitc.array<1x2xi32>) {
+    // expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<1x2xi32>' and result type '!emitc.ptr<i32>' are cast incompatible}}
+    %1 = emitc.cast %arg: !emitc.array<1x2xi32> to !emitc.ptr<i32>
+    return
+}
+
+// -----
+
+func.func @cast_array_zero_rank(%arg : !emitc.array<0xi32>) {
+    // expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<0xi32>' and result type '!emitc.ptr<i32>' are cast incompatible}}
+    %1 = emitc.cast %arg: !emitc.array<0xi32> to !emitc.ptr<i32>
+    return
+}
+
+// -----
+
+func.func @cast_array_to_pointer_types_mismatch(%arg : !emitc.array<3xi32>) {
+    // expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<3xi32>' and result type '!emitc.ptr<f16>' are cast incompatible}}
+    %1 = emitc.cast %arg: !emitc.array<3xi32> to !emitc.ptr<f16>
+    return
+}
+
+// -----
+
+func.func @cast_pointer_to_array(%arg : !emitc.ptr<i32>) {
+    // expected-error @+1 {{'emitc.cast' op operand type '!emitc.ptr<i32>' and result type '!emitc.array<3xi32>' are cast incompatible}}
+    %1 = emitc.cast %arg: !emitc.ptr<i32> to !emitc.array<3xi32>
     return
 }
 

diff  --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 645009bcc3c36..36d12e763afc7 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -39,6 +39,11 @@ func.func @cast(%arg0: i32) {
   return
 }
 
+func.func @cast_array_to_pointer(%arg0: !emitc.array<3xi32>) {
+  %1 = emitc.cast %arg0: !emitc.array<3xi32> to !emitc.ptr<i32>
+  return
+}
+
 func.func @c() {
   %1 = "emitc.constant"(){value = 42 : i32} : () -> i32
   %2 = "emitc.constant"(){value = 42 : index} : () -> !emitc.size_t


        


More information about the Mlir-commits mailing list