[Mlir-commits] [mlir] [MLIR][DataLayout] Add support for scalable vectors (PR #89349)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 18 23:27:45 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Christian Ulmann (Dinistro)

<details>
<summary>Changes</summary>

This commit extends the data layout to support scalable vectors. For scalable vectors, the `TypeSize`'s scalable field is set accordingly, and the alignment information remains the same as for normal vectors. This behavior is consistent with what LLVM's data layout queries produce.

---
Full diff: https://github.com/llvm/llvm-project/pull/89349.diff


3 Files Affected:

- (modified) mlir/lib/Interfaces/DataLayoutInterfaces.cpp (+10-6) 
- (modified) mlir/test/Interfaces/DataLayoutInterfaces/query.mlir (+12) 
- (modified) mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp (+9-2) 


``````````diff
diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index e93a9efbb76c17..ecf83c020799ee 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -75,10 +75,12 @@ mlir::detail::getDefaultTypeSizeInBits(Type type, const DataLayout &dataLayout,
   // there is no bit-packing at the moment element sizes are taken in bytes and
   // multiplied with 8 bits.
   // TODO: make this extensible.
-  if (auto vecType = dyn_cast<VectorType>(type))
-    return vecType.getNumElements() / vecType.getShape().back() *
-           llvm::PowerOf2Ceil(vecType.getShape().back()) *
-           dataLayout.getTypeSize(vecType.getElementType()) * 8;
+  if (auto vecType = dyn_cast<VectorType>(type)) {
+    uint64_t baseSize = vecType.getNumElements() / vecType.getShape().back() *
+                        llvm::PowerOf2Ceil(vecType.getShape().back()) *
+                        dataLayout.getTypeSize(vecType.getElementType()) * 8;
+    return llvm::TypeSize::get(baseSize, vecType.isScalable());
+  }
 
   if (auto typeInterface = dyn_cast<DataLayoutTypeInterface>(type))
     return typeInterface.getTypeSizeInBits(dataLayout, params);
@@ -138,9 +140,11 @@ getFloatTypeABIAlignment(FloatType fltType, const DataLayout &dataLayout,
 uint64_t mlir::detail::getDefaultABIAlignment(
     Type type, const DataLayout &dataLayout,
     ArrayRef<DataLayoutEntryInterface> params) {
-  // Natural alignment is the closest power-of-two number above.
+  // Natural alignment is the closest power-of-two number above. For scalable
+  // vectors, aligning them to the same as the base vector is sufficient.
+  // This should be consistent with LLVM.
   if (isa<VectorType>(type))
-    return llvm::PowerOf2Ceil(dataLayout.getTypeSize(type));
+    return llvm::PowerOf2Ceil(dataLayout.getTypeSize(type).getKnownMinValue());
 
   if (auto fltType = dyn_cast<FloatType>(type))
     return getFloatTypeABIAlignment(fltType, dataLayout, params);
diff --git a/mlir/test/Interfaces/DataLayoutInterfaces/query.mlir b/mlir/test/Interfaces/DataLayoutInterfaces/query.mlir
index d3bc91339d164b..11c06c38452c85 100644
--- a/mlir/test/Interfaces/DataLayoutInterfaces/query.mlir
+++ b/mlir/test/Interfaces/DataLayoutInterfaces/query.mlir
@@ -32,6 +32,18 @@ func.func @no_layout_builtin() {
   // CHECK: preferred = 8
   // CHECK: size = 8
   "test.data_layout_query"() : () -> index
+  // CHECK: alignment = 16
+  // CHECK: bitsize = 128
+  // CHECK: index = 0
+  // CHECK: preferred = 16
+  // CHECK: size = 16
+  "test.data_layout_query"() : () -> vector<4xi32>
+  // CHECK: alignment = 16
+  // CHECK: bitsize = "scalable"
+  // CHECK: index = 0
+  // CHECK: preferred = 16
+  // CHECK: size = "scalable"
+  "test.data_layout_query"() : () -> vector<[4]xi32>
   return
 
 }
diff --git a/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp b/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
index a4464bba7e8584..b5eb2d99e71140 100644
--- a/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
+++ b/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
@@ -46,9 +46,16 @@ struct TestDataLayoutQuery
       Attribute programMemorySpace = layout.getProgramMemorySpace();
       Attribute globalMemorySpace = layout.getGlobalMemorySpace();
       uint64_t stackAlignment = layout.getStackAlignment();
+
+      auto convertTypeSizeToAttr = [&](llvm::TypeSize typeSize) -> Attribute {
+        if (typeSize.isScalable())
+          return builder.getStringAttr("scalable");
+        return builder.getIndexAttr(typeSize);
+      };
+
       op->setAttrs(
-          {builder.getNamedAttr("size", builder.getIndexAttr(size)),
-           builder.getNamedAttr("bitsize", builder.getIndexAttr(bitsize)),
+          {builder.getNamedAttr("size", convertTypeSizeToAttr(size)),
+           builder.getNamedAttr("bitsize", convertTypeSizeToAttr(bitsize)),
            builder.getNamedAttr("alignment", builder.getIndexAttr(alignment)),
            builder.getNamedAttr("preferred", builder.getIndexAttr(preferred)),
            builder.getNamedAttr("index", builder.getIndexAttr(index)),

``````````

</details>


https://github.com/llvm/llvm-project/pull/89349


More information about the Mlir-commits mailing list