[Mlir-commits] [mlir] d1cb685 - [mlir][IR] Remove ShapedType::getSizeInBits

Matthias Springer llvmlistbot at llvm.org
Tue Apr 18 19:13:29 PDT 2023


Author: Matthias Springer
Date: 2023-04-19T11:01:33+09:00
New Revision: d1cb68525c4a0127f2d823d2c8c49791f55d3553

URL: https://github.com/llvm/llvm-project/commit/d1cb68525c4a0127f2d823d2c8c49791f55d3553
DIFF: https://github.com/llvm/llvm-project/commit/d1cb68525c4a0127f2d823d2c8c49791f55d3553.diff

LOG: [mlir][IR] Remove ShapedType::getSizeInBits

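This function returns incorrect values for memrefs and vectors due to "widening": it ignores memory layout and widening constraints, e.g. a vector<3xi57> is reported as 3x57=171 bits even though it will likely be stored in a 4xi64 vector register.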

Differential Revision: https://reviews.llvm.org/D148501

Added: 
    

Modified: 
    mlir/include/mlir/IR/BuiltinTypeInterfaces.td
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/lib/IR/BuiltinTypeInterfaces.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-derived-attribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index f2b1fa34bc391..bb38985715c09 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -99,15 +99,6 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
 
     /// Return the number of elements present in the given shape.
     static int64_t getNumElements(ArrayRef<int64_t> shape);
-
-    /// Returns the total amount of bits occupied by a value of this type. This
-    /// does not take into account any memory layout or widening constraints,
-    /// e.g. a vector<3xi57> may report to occupy 3x57=171 bit, even though in
-    /// practice it will likely be stored as in a 4xi64 vector register. Fails
-    /// with an assertion if the size cannot be computed statically, e.g. if the
-    /// type has a dynamic shape or if its elemental type does not have a known
-    /// bit width.
-    int64_t getSizeInBits() const;
   }];
 
   let extraSharedClassDeclaration = [{

diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index dae90f26199a4..50017b7bcef9b 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -40,8 +40,11 @@ static uint64_t getFirstIntValue(ArrayAttr attr) {
 
 /// Returns the number of bits for the given scalar/vector type.
 static int getNumBits(Type type) {
+  // TODO: This does not take into account any memory layout or widening
+  // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even
+  // though in practice it will likely be stored as in a 4xi64 vector register.
   if (auto vectorType = type.dyn_cast<VectorType>())
-    return vectorType.cast<ShapedType>().getSizeInBits();
+    return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
   return type.getIntOrFloatBitWidth();
 }
 

diff  --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
index 88791fc66fbf2..ab9e65b5edfed 100644
--- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
@@ -33,18 +33,3 @@ int64_t ShapedType::getNumElements(ArrayRef<int64_t> shape) {
   }
   return num;
 }
-
-int64_t ShapedType::getSizeInBits() const {
-  assert(hasStaticShape() &&
-         "cannot get the bit size of an aggregate with a dynamic shape");
-
-  auto elementType = getElementType();
-  if (elementType.isIntOrFloat())
-    return elementType.getIntOrFloatBitWidth() * getNumElements();
-
-  if (auto complexType = elementType.dyn_cast<ComplexType>()) {
-    elementType = complexType.getElementType();
-    return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
-  }
-  return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
-}

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 0306f0ed02f99..9381409f07844 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -259,8 +259,8 @@ def DerivedTypeAttrOp : TEST_Op<"derived_type_attr", []> {
   let results = (outs AnyTensor:$output);
   DerivedTypeAttr element_dtype =
     DerivedTypeAttr<"return getElementTypeOrSelf(getOutput().getType());">;
-  DerivedAttr size = DerivedAttr<"int",
-    "return getOutput().getType().cast<ShapedType>().getSizeInBits();",
+  DerivedAttr num_elements = DerivedAttr<"int",
+    "return getOutput().getType().cast<ShapedType>().getNumElements();",
     "$_builder.getI32IntegerAttr($_self)">;
 }
 

diff  --git a/mlir/test/mlir-tblgen/op-derived-attribute.mlir b/mlir/test/mlir-tblgen/op-derived-attribute.mlir
index 27fa3ab821ac0..2b0e6ed4994b6 100644
--- a/mlir/test/mlir-tblgen/op-derived-attribute.mlir
+++ b/mlir/test/mlir-tblgen/op-derived-attribute.mlir
@@ -3,15 +3,15 @@
 // CHECK-LABEL: verifyDerivedAttributes
 func.func @verifyDerivedAttributes() {
   // expected-remark @+2 {{element_dtype = f32}}
-  // expected-remark @+1 {{size = 320}}
+  // expected-remark @+1 {{num_elements = 10}}
   %0 = "test.derived_type_attr"() : () -> tensor<10xf32>
 
   // expected-remark @+2 {{element_dtype = i79}}
-  // expected-remark @+1 {{size = 948}}
+  // expected-remark @+1 {{num_elements = 12}}
   %1 = "test.derived_type_attr"() : () -> tensor<12xi79>
 
   // expected-remark @+2 {{element_dtype = complex<f32>}}
-  // expected-remark @+1 {{size = 768}}
+  // expected-remark @+1 {{num_elements = 12}}
   %2 = "test.derived_type_attr"() : () -> tensor<12xcomplex<f32>>
 
   return


        


More information about the Mlir-commits mailing list