[Mlir-commits] [mlir] d1cb685 - [mlir][IR] Remove ShapedType::getSizeInBits
Matthias Springer
llvmlistbot at llvm.org
Tue Apr 18 19:13:29 PDT 2023
Author: Matthias Springer
Date: 2023-04-19T11:01:33+09:00
New Revision: d1cb68525c4a0127f2d823d2c8c49791f55d3553
URL: https://github.com/llvm/llvm-project/commit/d1cb68525c4a0127f2d823d2c8c49791f55d3553
DIFF: https://github.com/llvm/llvm-project/commit/d1cb68525c4a0127f2d823d2c8c49791f55d3553.diff
LOG: [mlir][IR] Remove ShapedType::getSizeInBits
This function returns incorrect values for memrefs and vectors due to "widening".
Differential Revision: https://reviews.llvm.org/D148501
Added:
Modified:
mlir/include/mlir/IR/BuiltinTypeInterfaces.td
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/lib/IR/BuiltinTypeInterfaces.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/op-derived-attribute.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index f2b1fa34bc391..bb38985715c09 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -99,15 +99,6 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
/// Return the number of elements present in the given shape.
static int64_t getNumElements(ArrayRef<int64_t> shape);
-
- /// Returns the total amount of bits occupied by a value of this type. This
- /// does not take into account any memory layout or widening constraints,
- /// e.g. a vector<3xi57> may report to occupy 3x57=171 bit, even though in
- /// practice it will likely be stored as in a 4xi64 vector register. Fails
- /// with an assertion if the size cannot be computed statically, e.g. if the
- /// type has a dynamic shape or if its elemental type does not have a known
- /// bit width.
- int64_t getSizeInBits() const;
}];
let extraSharedClassDeclaration = [{
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index dae90f26199a4..50017b7bcef9b 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -40,8 +40,11 @@ static uint64_t getFirstIntValue(ArrayAttr attr) {
/// Returns the number of bits for the given scalar/vector type.
static int getNumBits(Type type) {
+ // TODO: This does not take into account any memory layout or widening
+ // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even
+ // though in practice it will likely be stored as in a 4xi64 vector register.
if (auto vectorType = type.dyn_cast<VectorType>())
- return vectorType.cast<ShapedType>().getSizeInBits();
+ return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
return type.getIntOrFloatBitWidth();
}
diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
index 88791fc66fbf2..ab9e65b5edfed 100644
--- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
@@ -33,18 +33,3 @@ int64_t ShapedType::getNumElements(ArrayRef<int64_t> shape) {
}
return num;
}
-
-int64_t ShapedType::getSizeInBits() const {
- assert(hasStaticShape() &&
- "cannot get the bit size of an aggregate with a dynamic shape");
-
- auto elementType = getElementType();
- if (elementType.isIntOrFloat())
- return elementType.getIntOrFloatBitWidth() * getNumElements();
-
- if (auto complexType = elementType.dyn_cast<ComplexType>()) {
- elementType = complexType.getElementType();
- return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
- }
- return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
-}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 0306f0ed02f99..9381409f07844 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -259,8 +259,8 @@ def DerivedTypeAttrOp : TEST_Op<"derived_type_attr", []> {
let results = (outs AnyTensor:$output);
DerivedTypeAttr element_dtype =
DerivedTypeAttr<"return getElementTypeOrSelf(getOutput().getType());">;
- DerivedAttr size = DerivedAttr<"int",
- "return getOutput().getType().cast<ShapedType>().getSizeInBits();",
+ DerivedAttr num_elements = DerivedAttr<"int",
+ "return getOutput().getType().cast<ShapedType>().getNumElements();",
"$_builder.getI32IntegerAttr($_self)">;
}
diff --git a/mlir/test/mlir-tblgen/op-derived-attribute.mlir b/mlir/test/mlir-tblgen/op-derived-attribute.mlir
index 27fa3ab821ac0..2b0e6ed4994b6 100644
--- a/mlir/test/mlir-tblgen/op-derived-attribute.mlir
+++ b/mlir/test/mlir-tblgen/op-derived-attribute.mlir
@@ -3,15 +3,15 @@
// CHECK-LABEL: verifyDerivedAttributes
func.func @verifyDerivedAttributes() {
// expected-remark @+2 {{element_dtype = f32}}
- // expected-remark @+1 {{size = 320}}
+ // expected-remark @+1 {{num_elements = 10}}
%0 = "test.derived_type_attr"() : () -> tensor<10xf32>
// expected-remark @+2 {{element_dtype = i79}}
- // expected-remark @+1 {{size = 948}}
+ // expected-remark @+1 {{num_elements = 12}}
%1 = "test.derived_type_attr"() : () -> tensor<12xi79>
// expected-remark @+2 {{element_dtype = complex<f32>}}
- // expected-remark @+1 {{size = 768}}
+ // expected-remark @+1 {{num_elements = 12}}
%2 = "test.derived_type_attr"() : () -> tensor<12xcomplex<f32>>
return
More information about the Mlir-commits mailing list