[Mlir-commits] [mlir] ca23c93 - [mlir][python] Create all missing attribute builders.

Ingo Müller llvmlistbot at llvm.org
Wed Sep 6 00:09:30 PDT 2023


Author: Ingo Müller
Date: 2023-09-06T07:09:25Z
New Revision: ca23c933bda9bdb4ef4cc8204834167f5f031401

URL: https://github.com/llvm/llvm-project/commit/ca23c933bda9bdb4ef4cc8204834167f5f031401
DIFF: https://github.com/llvm/llvm-project/commit/ca23c933bda9bdb4ef4cc8204834167f5f031401.diff

LOG: [mlir][python] Create all missing attribute builders.

This patch adds attribute builders for all buildable attributes from the
builtin dialect that did not previously have any. These builders can be
used to construct attributes of a particular type identified by a string
from a Python argument without knowing the details of how to pass that
Python argument to the attribute constructor. This is used, for example,
in the generated code of the Python bindings of ops.

The list of "all" attributes was produced with:

(
  grep -h "ods_ir.AttrBuilder.get" $(find ../build/ -name "*_ops_gen.py") \
    | cut -f2 -d"'"
  git grep -ho "^def [a-zA-Z0-9_]*" -- include/mlir/IR/CommonAttrConstraints.td \
    | cut -f2 -d" "
) | sort -u

Then, I only retained those that had an occurrence in
`mlir/include/mlir/IR`. In particular, this drops many dialect-specific
attributes; registering those builders is something that those dialects
should do. Finally, I removed those attributes that had a match in
`mlir/python/mlir/ir.py` already and implemented the remaining ones. The
only ones that still miss a builder now are the following:

* Represent more than one possible attribute type:
  - `Any.*Attr` (9x)
  - `IntNonNegative`
  - `IntPositive`
  - `IsNullAttr`
  - `ElementsAttr`
* I am not sure what "constant attributes" are:
  - `ConstBoolAttrFalse`
  - `ConstBoolAttrTrue`
  - `ConstUnitAttr`
* `Location` not exposed by Python bindings:
  - `LocationArrayAttr`
  - `LocationAttr`
* `get` function not implemented in Python bindings:
  - `StringElementsAttr`

This patch also fixes a compilation problem with
`I64SmallVectorArrayAttr`.

Reviewed By: makslevental, rkayaith

Differential Revision: https://reviews.llvm.org/D159403

Added: 
    

Modified: 
    mlir/include/mlir/IR/CommonAttrConstraints.td
    mlir/python/mlir/ir.py
    mlir/test/python/dialects/python_test.py
    mlir/test/python/python_test_ops.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index dc091c36395601..0312ac7ec1d8df 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -612,7 +612,7 @@ def I64SmallVectorArrayAttr :
   let convertFromStorage = [{
     llvm::to_vector<4>(
       llvm::map_range($_self.getAsRange<mlir::IntegerAttr>(),
-      [](IntegerAttr attr) { return attr.getInt(); }));
+      [](mlir::IntegerAttr attr) { return attr.getInt(); }));
   }];
   let constBuilderCall = "$_builder.getI64ArrayAttr($0)";
 }

diff  --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 36c49fe6f1d6bd..43553f3118a51f 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -16,16 +16,36 @@ def decorator_builder(func):
     return decorator_builder
 
 
+@register_attribute_builder("AffineMapAttr")
+def _affineMapAttr(x, context):
+    return AffineMapAttr.get(x)
+
+
 @register_attribute_builder("BoolAttr")
 def _boolAttr(x, context):
     return BoolAttr.get(x, context=context)
 
 
+@register_attribute_builder("DictionaryAttr")
+def _dictAttr(x, context):
+    return DictAttr.get(x, context=context)
+
+
 @register_attribute_builder("IndexAttr")
 def _indexAttr(x, context):
     return IntegerAttr.get(IndexType.get(context=context), x)
 
 
+@register_attribute_builder("I1Attr")
+def _i1Attr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(1, context=context), x)
+
+
+@register_attribute_builder("I8Attr")
+def _i8Attr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(8, context=context), x)
+
+
 @register_attribute_builder("I16Attr")
 def _i16Attr(x, context):
     return IntegerAttr.get(IntegerType.get_signless(16, context=context), x)
@@ -41,6 +61,16 @@ def _i64Attr(x, context):
     return IntegerAttr.get(IntegerType.get_signless(64, context=context), x)
 
 
+@register_attribute_builder("SI1Attr")
+def _si1Attr(x, context):
+    return IntegerAttr.get(IntegerType.get_signed(1, context=context), x)
+
+
+@register_attribute_builder("SI8Attr")
+def _si8Attr(x, context):
+    return IntegerAttr.get(IntegerType.get_signed(8, context=context), x)
+
+
 @register_attribute_builder("SI16Attr")
 def _si16Attr(x, context):
     return IntegerAttr.get(IntegerType.get_signed(16, context=context), x)
@@ -51,6 +81,36 @@ def _si32Attr(x, context):
     return IntegerAttr.get(IntegerType.get_signed(32, context=context), x)
 
 
+@register_attribute_builder("SI64Attr")
+def _si64Attr(x, context):
+    return IntegerAttr.get(IntegerType.get_signed(64, context=context), x)
+
+
+@register_attribute_builder("UI1Attr")
+def _ui1Attr(x, context):
+    return IntegerAttr.get(IntegerType.get_unsigned(1, context=context), x)
+
+
+@register_attribute_builder("UI8Attr")
+def _ui8Attr(x, context):
+    return IntegerAttr.get(IntegerType.get_unsigned(8, context=context), x)
+
+
+@register_attribute_builder("UI16Attr")
+def _ui16Attr(x, context):
+    return IntegerAttr.get(IntegerType.get_unsigned(16, context=context), x)
+
+
+@register_attribute_builder("UI32Attr")
+def _ui32Attr(x, context):
+    return IntegerAttr.get(IntegerType.get_unsigned(32, context=context), x)
+
+
+@register_attribute_builder("UI64Attr")
+def _ui64Attr(x, context):
+    return IntegerAttr.get(IntegerType.get_unsigned(64, context=context), x)
+
+
 @register_attribute_builder("F32Attr")
 def _f32Attr(x, context):
     return FloatAttr.get_f32(x, context=context)
@@ -84,11 +144,39 @@ def _flatSymbolRefAttr(x, context):
     return FlatSymbolRefAttr.get(x, context=context)
 
 
+@register_attribute_builder("UnitAttr")
+def _unitAttr(x, context):
+    if x:
+        return UnitAttr.get(context=context)
+    else:
+        return None
+
+
 @register_attribute_builder("ArrayAttr")
 def _arrayAttr(x, context):
     return ArrayAttr.get(x, context=context)
 
 
+@register_attribute_builder("AffineMapArrayAttr")
+def _affineMapArrayAttr(x, context):
+    return ArrayAttr.get([_affineMapAttr(v, context) for v in x])
+
+
+@register_attribute_builder("BoolArrayAttr")
+def _boolArrayAttr(x, context):
+    return ArrayAttr.get([_boolAttr(v, context) for v in x])
+
+
+@register_attribute_builder("DictArrayAttr")
+def _dictArrayAttr(x, context):
+    return ArrayAttr.get([_dictAttr(v, context) for v in x])
+
+
+@register_attribute_builder("FlatSymbolRefArrayAttr")
+def _flatSymbolRefArrayAttr(x, context):
+    return ArrayAttr.get([_flatSymbolRefAttr(v, context) for v in x])
+
+
 @register_attribute_builder("I32ArrayAttr")
 def _i32ArrayAttr(x, context):
     return ArrayAttr.get([_i32Attr(v, context) for v in x])
@@ -99,6 +187,16 @@ def _i64ArrayAttr(x, context):
     return ArrayAttr.get([_i64Attr(v, context) for v in x])
 
 
+@register_attribute_builder("I64SmallVectorArrayAttr")
+def _i64SmallVectorArrayAttr(x, context):
+    return _i64ArrayAttr(x, context=context)
+
+
+@register_attribute_builder("IndexListArrayAttr")
+def _indexListArrayAttr(x, context):
+    return ArrayAttr.get([_i64ArrayAttr(v, context) for v in x])
+
+
 @register_attribute_builder("F32ArrayAttr")
 def _f32ArrayAttr(x, context):
     return ArrayAttr.get([_f32Attr(v, context) for v in x])
@@ -109,6 +207,41 @@ def _f64ArrayAttr(x, context):
     return ArrayAttr.get([_f64Attr(v, context) for v in x])
 
 
+@register_attribute_builder("StrArrayAttr")
+def _strArrayAttr(x, context):
+    return ArrayAttr.get([_stringAttr(v, context) for v in x])
+
+
+@register_attribute_builder("SymbolRefArrayAttr")
+def _symbolRefArrayAttr(x, context):
+    return ArrayAttr.get([_symbolRefAttr(v, context) for v in x])
+
+
+@register_attribute_builder("DenseF32ArrayAttr")
+def _denseF32ArrayAttr(x, context):
+    return DenseF32ArrayAttr.get(x, context=context)
+
+
+@register_attribute_builder("DenseF64ArrayAttr")
+def _denseF64ArrayAttr(x, context):
+    return DenseF64ArrayAttr.get(x, context=context)
+
+
+@register_attribute_builder("DenseI8ArrayAttr")
+def _denseI8ArrayAttr(x, context):
+    return DenseI8ArrayAttr.get(x, context=context)
+
+
+@register_attribute_builder("DenseI16ArrayAttr")
+def _denseI16ArrayAttr(x, context):
+    return DenseI16ArrayAttr.get(x, context=context)
+
+
+@register_attribute_builder("DenseI32ArrayAttr")
+def _denseI32ArrayAttr(x, context):
+    return DenseI32ArrayAttr.get(x, context=context)
+
+
 @register_attribute_builder("DenseI64ArrayAttr")
 def _denseI64ArrayAttr(x, context):
     return DenseI64ArrayAttr.get(x, context=context)
@@ -132,6 +265,30 @@ def _typeArrayAttr(x, context):
 try:
     import numpy as np
 
+    @register_attribute_builder("F64ElementsAttr")
+    def _f64ElementsAttr(x, context):
+        return DenseElementsAttr.get(
+            np.array(x, dtype=np.float64),  # float64 buffer for the f64 element type
+            type=F64Type.get(context=context),
+            context=context,
+        )
+
+    @register_attribute_builder("I32ElementsAttr")
+    def _i32ElementsAttr(x, context):
+        return DenseElementsAttr.get(
+            np.array(x, dtype=np.int32),
+            type=IntegerType.get_signed(32, context=context),
+            context=context,
+        )
+
+    @register_attribute_builder("I64ElementsAttr")
+    def _i64ElementsAttr(x, context):
+        return DenseElementsAttr.get(
+            np.array(x, dtype=np.int64),
+            type=IntegerType.get_signed(64, context=context),
+            context=context,
+        )
+
     @register_attribute_builder("IndexElementsAttr")
     def _indexElementsAttr(x, context):
         return DenseElementsAttr.get(

diff  --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 37e508fb979efc..651e6554eebe8b 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -140,23 +140,76 @@ def testAttributes():
 def attrBuilder():
     with Context() as ctx, Location.unknown():
         ctx.allow_unregistered_dialects = True
+        # CHECK: python_test.attributes_op
         op = test.AttributesOp(
-            x_bool=True,
-            x_i16=1,
-            x_i32=2,
-            x_i64=3,
-            x_si16=-1,
-            x_si32=-2,
-            x_f32=1.5,
-            x_f64=2.5,
-            x_str="x_str",
-            x_i32_array=[1, 2, 3],
-            x_i64_array=[4, 5, 6],
-            x_f32_array=[1.5, -2.5, 3.5],
-            x_f64_array=[4.5, 5.5, -6.5],
-            x_i64_dense=[1, 2, 3, 4, 5, 6],
+            # CHECK-DAG: x_affinemap = affine_map<() -> (2)>
+            x_affinemap=AffineMap.get_constant(2),
+            # CHECK-DAG: x_affinemaparr = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+            x_affinemaparr=[AffineMap.get_identity(3)],
+            # CHECK-DAG: x_arr = [true, "x"]
+            x_arr=[BoolAttr.get(True), StringAttr.get("x")],
+            x_boolarr=[False, True],  # CHECK-DAG: x_boolarr = [false, true]
+            x_bool=True,  # CHECK-DAG: x_bool = true
+            x_dboolarr=[True, False],  # CHECK-DAG: x_dboolarr = array<i1: true, false>
+            x_df16arr=[21, 22],  # CHECK-DAG: x_df16arr = array<i16: 21, 22>
+            # CHECK-DAG: x_df32arr = array<f32: 2.300000e+01, 2.400000e+01>
+            x_df32arr=[23, 24],
+            # CHECK-DAG: x_df64arr = array<f64: 2.500000e+01, 2.600000e+01>
+            x_df64arr=[25, 26],
+            x_di32arr=[0, 1],  # CHECK-DAG: x_di32arr = array<i32: 0, 1>
+            # CHECK-DAG: x_di64arr = array<i64: 1, 2>
+            x_di64arr=[1, 2],
+            x_di8arr=[2, 3],  # CHECK-DAG: x_di8arr = array<i8: 2, 3>
+            # CHECK-DAG: x_dictarr = [{a = false}]
+            x_dictarr=[{"a": BoolAttr.get(False)}],
+            x_dict={"b": BoolAttr.get(True)},  # CHECK-DAG: x_dict = {b = true}
+            x_f32=-2.25,  # CHECK-DAG: x_f32 = -2.250000e+00 : f32
+            # CHECK-DAG: x_f32arr = [2.000000e+00 : f32, 3.000000e+00 : f32]
+            x_f32arr=[2.0, 3.0],
+            x_f64=4.25,  # CHECK-DAG: x_f64 = 4.250000e+00 : f64
+            x_f64arr=[4.0, 8.0],  # CHECK-DAG: x_f64arr = [4.000000e+00, 8.000000e+00]
+            # CHECK-DAG: x_f64elems = dense<[8.000000e+00, 1.600000e+01]> : tensor<2xf64>
+            x_f64elems=[8.0, 16.0],
+            # CHECK-DAG: x_flatsymrefarr = [@symbol1, @symbol2]
+            x_flatsymrefarr=["symbol1", "symbol2"],
+            x_flatsymref="symbol3",  # CHECK-DAG: x_flatsymref = @symbol3
+            x_i1=0,  # CHECK-DAG: x_i1 = false
+            x_i16=42,  # CHECK-DAG: x_i16 = 42 : i16
+            x_i32=6,  # CHECK-DAG: x_i32 = 6 : i32
+            x_i32arr=[4, 5],  # CHECK-DAG: x_i32arr = [4 : i32, 5 : i32]
+            x_i32elems=[5, 6],  # CHECK-DAG: x_i32elems = dense<[5, 6]> : tensor<2xsi32>
+            x_i64=9,  # CHECK-DAG: x_i64 = 9 : i64
+            x_i64arr=[7, 8],  # CHECK-DAG: x_i64arr = [7, 8]
+            x_i64elems=[8, 9],  # CHECK-DAG: x_i64elems = dense<[8, 9]> : tensor<2xsi64>
+            x_i64svecarr=[10, 11],  # CHECK-DAG: x_i64svecarr = [10, 11]
+            x_i8=11,  # CHECK-DAG: x_i8 = 11 : i8
+            x_idx=10,  # CHECK-DAG: x_idx = 10 : index
+            # CHECK-DAG: x_idxelems = dense<[11, 12]> : tensor<2xindex>
+            x_idxelems=[11, 12],
+            # CHECK-DAG: x_idxlistarr = [{{\[}}13], [14, 15]]
+            x_idxlistarr=[[13], [14, 15]],
+            x_si1=-1,  # CHECK-DAG: x_si1 = -1 : si1
+            x_si16=-2,  # CHECK-DAG: x_si16 = -2 : si16
+            x_si32=-3,  # CHECK-DAG: x_si32 = -3 : si32
+            x_si64=-123,  # CHECK-DAG: x_si64 = -123 : si64
+            x_si8=-4,  # CHECK-DAG: x_si8 = -4 : si8
+            x_strarr=["hello", "world"],  # CHECK-DAG: x_strarr = ["hello", "world"]
+            x_str="hello world!",  # CHECK-DAG: x_str = "hello world!"
+            # CHECK-DAG: x_symrefarr = [@flatsym, @deep::@sym]
+            x_symrefarr=["flatsym", ["deep", "sym"]],
+            x_symref=["deep", "sym2"],  # CHECK-DAG: x_symref = @deep::@sym2
+            x_sym="symbol",  # CHECK-DAG: x_sym = "symbol"
+            x_typearr=[F32Type.get()],  # CHECK-DAG: x_typearr = [f32]
+            x_type=F64Type.get(),  # CHECK-DAG: x_type = f64
+            x_ui1=1,  # CHECK-DAG: x_ui1 = 1 : ui1
+            x_ui16=2,  # CHECK-DAG: x_ui16 = 2 : ui16
+            x_ui32=3,  # CHECK-DAG: x_ui32 = 3 : ui32
+            x_ui64=4,  # CHECK-DAG: x_ui64 = 4 : ui64
+            x_ui8=5,  # CHECK-DAG: x_ui8 = 5 : ui8
+            x_unit=True,  # CHECK-DAG: x_unit
         )
-        print(op)
+        op.verify()
+        op.print(use_local_scope=True)
 
 
 # CHECK-LABEL: TEST: inferReturnTypes
@@ -247,7 +300,6 @@ def testOptionalOperandOp():
 
         module = Module.create()
         with InsertionPoint(module.body):
-
             op1 = test.OptionalOperandOp()
             # CHECK: op1.input is None: True
             print(f"op1.input is None: {op1.input is None}")

diff  --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index 2adbdcab71834d..d79714301ae951 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -57,20 +57,60 @@ def AttributedOp : TestOp<"attributed_op"> {
 }
 
 def AttributesOp : TestOp<"attributes_op"> {
-  let arguments = (ins BoolAttr:$x_bool,
-                   I16Attr: $x_i16,
-                   I32Attr: $x_i32,
-                   I64Attr: $x_i64,
-                   SI16Attr: $x_si16,
-                   SI32Attr: $x_si32,
-                   F32Attr: $x_f32,
-                   F64Attr: $x_f64,
-                   StrAttr: $x_str,
-                   I32ArrayAttr: $x_i32_array,
-                   I64ArrayAttr: $x_i64_array,
-                   F32ArrayAttr: $x_f32_array,
-                   F64ArrayAttr: $x_f64_array,
-                   DenseI64ArrayAttr: $x_i64_dense);
+  let arguments = (ins  // Every argument is set by attrBuilder() in mlir/test/python/dialects/python_test.py.
+                   AffineMapArrayAttr:$x_affinemaparr,
+                   AffineMapAttr:$x_affinemap,
+                   ArrayAttr:$x_arr,
+                   BoolArrayAttr:$x_boolarr,
+                   BoolAttr:$x_bool,
+                   DenseBoolArrayAttr:$x_dboolarr,
+                   DenseF32ArrayAttr:$x_df32arr,
+                   DenseF64ArrayAttr:$x_df64arr,
+                   DenseI16ArrayAttr:$x_df16arr,  // NOTE(review): holds i16 elements despite the "df16" name; x_di16arr would match the other Dense* names.
+                   DenseI32ArrayAttr:$x_di32arr,
+                   DenseI64ArrayAttr:$x_di64arr,
+                   DenseI8ArrayAttr:$x_di8arr,
+                   DictArrayAttr:$x_dictarr,
+                   DictionaryAttr:$x_dict,
+                   F32ArrayAttr:$x_f32arr,
+                   F32Attr:$x_f32,
+                   F64ArrayAttr:$x_f64arr,
+                   F64Attr:$x_f64,
+                   F64ElementsAttr:$x_f64elems,
+                   FlatSymbolRefArrayAttr:$x_flatsymrefarr,
+                   FlatSymbolRefAttr:$x_flatsymref,
+                   I16Attr:$x_i16,
+                   I1Attr:$x_i1,
+                   I32ArrayAttr:$x_i32arr,
+                   I32Attr:$x_i32,
+                   I32ElementsAttr:$x_i32elems,
+                   I64ArrayAttr:$x_i64arr,
+                   I64Attr:$x_i64,
+                   I64ElementsAttr:$x_i64elems,
+                   I64SmallVectorArrayAttr:$x_i64svecarr,
+                   I8Attr:$x_i8,
+                   IndexAttr:$x_idx,
+                   IndexElementsAttr:$x_idxelems,
+                   IndexListArrayAttr:$x_idxlistarr,
+                   SI16Attr:$x_si16,
+                   SI1Attr:$x_si1,
+                   SI32Attr:$x_si32,
+                   SI64Attr:$x_si64,
+                   SI8Attr:$x_si8,
+                   StrArrayAttr:$x_strarr,
+                   StrAttr:$x_str,
+                   SymbolNameAttr:$x_sym,
+                   SymbolRefArrayAttr:$x_symrefarr,
+                   SymbolRefAttr:$x_symref,
+                   TypeArrayAttr:$x_typearr,
+                   TypeAttr:$x_type,
+                   UI16Attr:$x_ui16,
+                   UI1Attr:$x_ui1,
+                   UI32Attr:$x_ui32,
+                   UI64Attr:$x_ui64,
+                   UI8Attr:$x_ui8,
+                   UnitAttr:$x_unit
+                   );
 }
 
 def PropertyOp : TestOp<"property_op"> {


        


More information about the Mlir-commits mailing list