[Mlir-commits] [mlir] 99dee31 - Make it possible to create DenseElementsAttrs with arbitrary shaped types in Python bindings

Jacques Pienaar llvmlistbot at llvm.org
Wed Mar 8 11:12:28 PST 2023


Author: Adam Paszke
Date: 2023-03-08T11:11:45-08:00
New Revision: 99dee31ef48012f8984ddab806b7345c24b02a72

URL: https://github.com/llvm/llvm-project/commit/99dee31ef48012f8984ddab806b7345c24b02a72
DIFF: https://github.com/llvm/llvm-project/commit/99dee31ef48012f8984ddab806b7345c24b02a72.diff

LOG: Make it possible to create DenseElementsAttrs with arbitrary shaped types in Python bindings

Right now the bindings assume that all DenseElementsAttrs correspond to tensor values,
making it impossible to create vector-typed constants. I didn't want to change the API
significantly, so I opted to reuse the current signature of `.get`. Its `type` argument
now accepts either an element type (in which case `shape` and `signless` can also be specified)
or a shaped type, which specifies the full type of the created attr (`shape` cannot be specified
in that case).
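The type-resolution rule described above can be sketched in plain Python. This is a hypothetical model, not the bindings' code. `ElementType`, `ShapedType`, and `resolve_attr_type` are illustrative names. The sketch omits `signless` handling and tensor encodings, and it assumes that without an explicit `shape` the buffer's own shape is used, as the `tensor<4xi32>` test output in this commit suggests. The real check is in C++ and throws `std::invalid_argument`, which pybind11 surfaces in Python as `ValueError`. The sketch raises `ValueError` directly.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple, Union

@dataclass(frozen=True)
class ElementType:
    """Hypothetical stand-in for a scalar element type such as i32."""
    name: str

@dataclass(frozen=True)
class ShapedType:
    """Hypothetical stand-in for a shaped type (tensor or vector)."""
    kind: str                # "tensor" or "vector"
    shape: Tuple[int, ...]
    element: ElementType

    def __str__(self) -> str:
        # Render like MLIR: tensor<2x2xi32>, vector<2x2xi32>, tensor<i32>
        dims = "".join(f"{d}x" for d in self.shape)
        return f"{self.kind}<{dims}{self.element.name}>"

def resolve_attr_type(type_: Union[ElementType, ShapedType],
                      buffer_shape: Sequence[int],
                      shape: Optional[Sequence[int]] = None) -> ShapedType:
    """Pick the attribute's full type the way the commit message describes."""
    if isinstance(type_, ShapedType):
        # A shaped type already fixes the full type, so an explicit shape conflicts.
        if shape is not None:
            raise ValueError("Shape can only be specified explicitly "
                             "when the type is not a shaped type.")
        return type_
    # An element type is wrapped in a ranked tensor. Assumption: an explicit
    # shape wins, otherwise the buffer's own shape is used.
    dims = tuple(shape) if shape is not None else tuple(buffer_shape)
    return ShapedType("tensor", dims, type_)

i32 = ElementType("i32")
print(resolve_attr_type(i32, (4,)))                 # tensor<4xi32>
print(resolve_attr_type(i32, (4,), shape=(2, 2)))   # tensor<2x2xi32>
print(resolve_attr_type(ShapedType("vector", (2, 2), i32), (4,)))  # vector<2x2xi32>
```

Wrapping a plain element type in a ranked tensor keeps existing callers working unchanged, which is why the `.get` signature did not need to change.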

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D145053

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/test/python/dialects/builtin.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index b0c35ffb8a53f..c59a54b6699a7 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -624,8 +624,17 @@ class PyDenseElementsAttribute
       }
     }
     if (bulkLoadElementType) {
-      auto shapedType = mlirRankedTensorTypeGet(
-          shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
+      MlirType shapedType;
+      if (mlirTypeIsAShaped(*bulkLoadElementType)) {
+        if (explicitShape) {
+          throw std::invalid_argument("Shape can only be specified explicitly "
+                                      "when the type is not a shaped type.");
+        }
+        shapedType = *bulkLoadElementType;
+      } else {
+        shapedType = mlirRankedTensorTypeGet(
+            shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
+      }
       size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
       MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
           shapedType, rawBufferSize, arrayInfo.ptr);

diff  --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py
index 94e29892ba7b1..eab24b5c796b0 100644
--- a/mlir/test/python/dialects/builtin.py
+++ b/mlir/test/python/dialects/builtin.py
@@ -3,6 +3,7 @@
 from mlir.ir import *
 import mlir.dialects.builtin as builtin
 import mlir.dialects.func as func
+import numpy as np
 
 
 def run(f):
@@ -221,3 +222,17 @@ def testFuncArgumentAccess():
   # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
   # CHECK: %{{.*}}: f32)
   print(module)
+
+
+# CHECK-LABEL: testDenseElementsAttr
+@run
+def testDenseElementsAttr():
+  with Context(), Location.unknown():
+    values = np.arange(4, dtype=np.int32)
+    i32 = IntegerType.get_signless(32)
+    print(DenseElementsAttr.get(values, type=i32))
+    # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : tensor<4xi32>
+    print(DenseElementsAttr.get(values, type=i32, shape=(2, 2)))
+    # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
+    print(DenseElementsAttr.get(values, type=VectorType.get((2, 2), i32)))
+    # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
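The test prints the same `np.arange(4)` buffer as `[[0, 1], [2, 3]]` under both `tensor<2x2xi32>` and `vector<2x2xi32>`. This happens because the C++ change passes the NumPy buffer unchanged to `mlirDenseElementsAttrRawBufferGet`, with `rawBufferSize = size * itemsize` (here 4 * 4 = 16 bytes). Only the shaped type decides how the elements are grouped. The sketch below models that grouping with hypothetical helpers `nest_row_major` and `dense_literal`. It assumes the row-major layout that DenseElementsAttr uses and only approximates the printed literal for small integer data. It is not MLIR's printer.

```python
import struct
from typing import List, Sequence, Union

Nested = Union[int, List["Nested"]]

def nest_row_major(flat: Sequence[int], shape: Sequence[int]) -> Nested:
    """Group a flat row-major sequence into nested lists following `shape`.

    Hypothetical helper: models how a shaped type interprets a raw buffer.
    """
    if len(shape) == 0:
        return flat[0]  # rank 0: a single scalar
    inner = 1
    for d in shape[1:]:
        inner *= d  # number of elements in each slice along the leading dim
    return [nest_row_major(flat[i * inner:(i + 1) * inner], shape[1:])
            for i in range(shape[0])]

def dense_literal(flat: Sequence[int], shape: Sequence[int]) -> str:
    """Render integer data roughly like MLIR's dense<...> literal (illustrative only)."""
    return f"dense<{nest_row_major(flat, shape)}>"

# Mirror np.arange(4, dtype=np.int32): four int32 values in native byte order.
values = list(range(4))
raw = struct.pack("=4i", *values)
assert len(raw) == 4 * 4  # size * itemsize, as in rawBufferSize

# Reinterpret the same bytes; only the shape differs between the two prints.
decoded = list(struct.unpack("=4i", raw))
print(dense_literal(decoded, (4,)))    # dense<[0, 1, 2, 3]>
print(dense_literal(decoded, (2, 2)))  # dense<[[0, 1], [2, 3]]>
```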


        

