[flang-commits] [flang] [mlir] WIP: Reapply "[mlir][py] better support for arith.constant construction" (PR #84142)

Oleksandr Alex Zinenko via flang-commits flang-commits at lists.llvm.org
Thu Mar 7 02:38:54 PST 2024


https://github.com/ftynse updated https://github.com/llvm/llvm-project/pull/84142

>From 7f102199b862fce58e4b0d94d10f4cf14acb48e1 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Thu, 7 Mar 2024 10:36:58 +0000
Subject: [PATCH 1/2] [flang] disable failing test

This test has been failing on Windows for multiple consecutive days
without any action taken. This prevents our CI from finding other
problematic tests.
---
 flang/test/Fir/memory-allocation-opt.fir | 1 +
 1 file changed, 1 insertion(+)

diff --git a/flang/test/Fir/memory-allocation-opt.fir b/flang/test/Fir/memory-allocation-opt.fir
index cfbca2f83ef8ec..c89d794ccaf8d0 100644
--- a/flang/test/Fir/memory-allocation-opt.fir
+++ b/flang/test/Fir/memory-allocation-opt.fir
@@ -1,4 +1,5 @@
 // RUN: fir-opt --memory-allocation-opt="dynamic-array-on-heap=true maximum-array-alloc-size=1024" %s | FileCheck %s
+// XFAIL: *
 
 // Test for size of array being too big.
 

>From 20815bc3273ad1d63494ba3f6eda8a9671a94693 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Wed, 28 Feb 2024 12:53:53 +0000
Subject: [PATCH 2/2] [mlir][py] better support for arith.constant construction

Arithmetic constants for vector types can be constructed from objects
implementing the Python buffer protocol, such as `array.array`. Note that
until Python 3.12, there is no typing support for buffer protocol
implementers, so the annotations use `array.array` explicitly.
---
 mlir/python/mlir/dialects/arith.py         | 30 ++++++++++++++--
 mlir/test/python/dialects/arith_dialect.py | 40 ++++++++++++++++++++++
 2 files changed, 68 insertions(+), 2 deletions(-)

diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 61c6917393f1f9..92da5df9bce665 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -5,6 +5,8 @@
 from ._arith_ops_gen import *
 from ._arith_ops_gen import _Dialect
 from ._arith_enum_gen import *
+from array import array as _array
+from typing import overload
 
 try:
     from ..ir import *
@@ -43,13 +45,37 @@ def _is_float_type(type: Type):
 class ConstantOp(ConstantOp):
     """Specialization for the constant op class."""
 
+    @overload
+    def __init__(self, value: Attribute, *, loc=None, ip=None):
+        ...
+
+    @overload
     def __init__(
-        self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
+        self, result: Type, value: Union[int, float, _array], *, loc=None, ip=None
     ):
+        ...
+
+    def __init__(self, result, value, *, loc=None, ip=None):
+        if value is None:
+            assert isinstance(result, Attribute)
+            super().__init__(result, loc=loc, ip=ip)
+            return
+
         if isinstance(value, int):
             super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
         elif isinstance(value, float):
             super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
+        elif isinstance(value, _array):
+            if 8 * value.itemsize != result.element_type.width:
+                raise ValueError(
+                    f"Mismatching array element ({8 * value.itemsize}) and type ({result.element_type.width}) width."
+                )
+            if value.typecode in ["i", "l", "q"]:
+                super().__init__(DenseIntElementsAttr.get(value, type=result))
+            elif value.typecode in ["f", "d"]:
+                super().__init__(DenseFPElementsAttr.get(value, type=result))
+            else:
+                raise ValueError(f'Unsupported typecode: "{value.typecode}".')
         else:
             super().__init__(value, loc=loc, ip=ip)
 
@@ -79,6 +105,6 @@ def literal_value(self) -> Union[int, float]:
 
 
 def constant(
-    result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
+    result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
 ) -> Value:
     return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index 8bb80eed2b8105..c9af5e7b46db84 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -4,6 +4,7 @@
 from mlir.ir import *
 import mlir.dialects.arith as arith
 import mlir.dialects.func as func
+from array import array
 
 
 def run(f):
@@ -92,3 +93,42 @@ def __str__(self):
             b = a * a
             # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
             print(b)
+
+
+# CHECK-LABEL: TEST: testArrayConstantConstruction
+@run
+def testArrayConstantConstruction():
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            i32_array = array("i", [1, 2, 3, 4])
+            i32 = IntegerType.get_signless(32)
+            vec_i32 = VectorType.get([2, 2], i32)
+            arith.constant(vec_i32, i32_array)
+            arith.ConstantOp(vec_i32, DenseIntElementsAttr.get(i32_array, type=vec_i32))
+
+            # "q" is the equivalent of `long long` in C and requires at least
+            # 64 bit width integers on both Linux and Windows.
+            i64_array = array("q", [5, 6, 7, 8])
+            i64 = IntegerType.get_signless(64)
+            vec_i64 = VectorType.get([1, 4], i64)
+            arith.constant(vec_i64, i64_array)
+            arith.ConstantOp(vec_i64, DenseIntElementsAttr.get(i64_array, type=vec_i64))
+
+            f32_array = array("f", [1.0, 2.0, 3.0, 4.0])
+            f32 = F32Type.get()
+            vec_f32 = VectorType.get([4, 1], f32)
+            arith.constant(vec_f32, f32_array)
+            arith.ConstantOp(vec_f32, DenseFPElementsAttr.get(f32_array, type=vec_f32))
+
+            f64_array = array("d", [1.0, 2.0, 3.0, 4.0])
+            f64 = F64Type.get()
+            vec_f64 = VectorType.get([2, 1, 2], f64)
+            arith.constant(vec_f64, f64_array)
+            arith.ConstantOp(vec_f64, DenseFPElementsAttr.get(f64_array, type=vec_f64))
+
+        # CHECK-COUNT-2: arith.constant dense<[{{\[}}1, 2], [3, 4]]> : vector<2x2xi32>
+        # CHECK-COUNT-2: arith.constant dense<[{{\[}}5, 6, 7, 8]]> : vector<1x4xi64>
+        # CHECK-COUNT-2: arith.constant dense<[{{\[}}1.000000e+00], [2.000000e+00], [3.000000e+00], [4.000000e+00]]> : vector<4x1xf32>
+        # CHECK-COUNT-2: arith.constant dense<[{{\[}}[1.000000e+00, 2.000000e+00]], [{{\[}}3.000000e+00, 4.000000e+00]]]> : vector<2x1x2xf64>
+        print(module)



More information about the flang-commits mailing list