[Mlir-commits] [mlir] b53fd9c - [MLIR] Add getSizeInBits() for tensor of complex

Tim Shen llvmlistbot at llvm.org
Fri Aug 7 12:39:09 PDT 2020


Author: Tim Shen
Date: 2020-08-07T12:38:49-07:00
New Revision: b53fd9cdba4da51284941fdecfe3c7490d6013cc

URL: https://github.com/llvm/llvm-project/commit/b53fd9cdba4da51284941fdecfe3c7490d6013cc
DIFF: https://github.com/llvm/llvm-project/commit/b53fd9cdba4da51284941fdecfe3c7490d6013cc.diff

LOG: [MLIR] Add getSizeInBits() for tensor of complex

Differential Revision: https://reviews.llvm.org/D85382

Added: 
    

Modified: 
    mlir/lib/IR/StandardTypes.cpp
    mlir/test/mlir-tblgen/op-derived-attribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index f878672cd912..f4bb79362ffd 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -230,6 +230,11 @@ int64_t ShapedType::getSizeInBits() const {
   if (elementType.isIntOrFloat())
     return elementType.getIntOrFloatBitWidth() * getNumElements();
 
+  if (auto complexType = elementType.dyn_cast<ComplexType>()) {
+    elementType = complexType.getElementType();
+    return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
+  }
+
   // Tensors can have vectors and other tensors as elements, other shaped types
   // cannot.
   assert(isa<TensorType>() && "unsupported element type");

diff  --git a/mlir/test/mlir-tblgen/op-derived-attribute.mlir b/mlir/test/mlir-tblgen/op-derived-attribute.mlir
index ec4f4dcf7dae..3312fd54811b 100644
--- a/mlir/test/mlir-tblgen/op-derived-attribute.mlir
+++ b/mlir/test/mlir-tblgen/op-derived-attribute.mlir
@@ -5,9 +5,14 @@ func @verifyDerivedAttributes() {
   // expected-remark @+2 {{element_dtype = f32}}
   // expected-remark @+1 {{size = 320}}
   %0 = "test.derived_type_attr"() : () -> tensor<10xf32>
+
   // expected-remark @+2 {{element_dtype = i79}}
   // expected-remark @+1 {{size = 948}}
   %1 = "test.derived_type_attr"() : () -> tensor<12xi79>
 
+  // expected-remark @+2 {{element_dtype = complex<f32>}}
+  // expected-remark @+1 {{size = 768}}
+  %2 = "test.derived_type_attr"() : () -> tensor<12xcomplex<f32>>
+
   return
 }
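
For readers checking the expected sizes in the test above, the arithmetic of the new complex branch can be sketched in standalone C++. This is a minimal illustration, not MLIR code: `ElemKind` and `sizeInBits` are hypothetical names introduced here. It assumes, as the patch does, that a complex element occupies twice the bit width of its underlying scalar type.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for an MLIR element type; not part of the MLIR API.
struct ElemKind {
  bool isComplex;      // true for complex<T>, false for a plain int/float
  unsigned scalarBits; // bit width of T, or of the scalar itself
};

// Mirrors the arithmetic in the patch: scalar width times element count,
// doubled for complex elements (real and imaginary parts).
int64_t sizeInBits(ElemKind elem, int64_t numElements) {
  assert(numElements >= 0 && "element count must be non-negative");
  int64_t bits = static_cast<int64_t>(elem.scalarBits) * numElements;
  return elem.isComplex ? bits * 2 : bits;
}
```

Under these assumptions, `sizeInBits(ElemKind{true, 32}, 12)` evaluates to 12 * 32 * 2 = 768, matching the `size = 768` remark expected for `tensor<12xcomplex<f32>>`. The existing cases follow the same formula without the doubling: 10 * 32 = 320 and 12 * 79 = 948.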


        

