[Mlir-commits] [mlir] [mlir][LLVMIR] Check number of elements in `mlir.constant` verifier (PR #102906)

Matthias Springer llvmlistbot at llvm.org
Tue Aug 13 01:13:30 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/102906

>From 8678ad86ee9e57b868cce6b8722bf9472e510bd5 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 12 Aug 2024 16:06:22 +0200
Subject: [PATCH] [mlir][LLVMIR] Check number of elements in `mlir.constant`
 verifier

Check that the number of elements in the result type of an `llvm.mlir.constant` op matches the number of elements in its attribute. Also fix a broken test where that was not the case.
---
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 13 ++--
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp  | 72 ++++++++++++++++++---
 mlir/test/Dialect/LLVMIR/invalid.mlir       | 16 +++++
 mlir/test/Target/LLVMIR/llvmir.mlir         |  8 ++-
 4 files changed, 94 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index c38a2584c8eec1..643522d5903fd0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1623,10 +1623,15 @@ def LLVM_ConstantOp
     vectors. It has a mandatory `value` attribute, which may be an integer,
     floating point attribute; dense or sparse attribute containing integers or
     floats. The type of the attribute is one of the corresponding MLIR builtin
-    types. It may be omitted for `i64` and `f64` types that are implied. The
-    operation produces a new SSA value of the specified LLVM IR dialect type.
-    The type of that value _must_ correspond to the attribute type converted to
-    LLVM IR.
+    types. It may be omitted for `i64` and `f64` types that are implied.
+
+    The operation produces a new SSA value of the specified LLVM IR dialect
+    type. Certain builtin types such as integer, float and vector types are
+    also allowed. The result type _must_ correspond to the attribute type
+    converted to LLVM IR. In particular, the number of elements of a container
+    type must match the number of elements in the attribute. If the type is or
+    contains a scalable vector type, the attribute must be a splat elements
+    attribute.
 
     Examples:
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 90610118a45cd2..07262bb8e1bacb 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2666,6 +2666,39 @@ OpFoldResult LLVM::ZeroOp::fold(FoldAdaptor) {
 // ConstantOp.
 //===----------------------------------------------------------------------===//
 
+/// Compute the total number of elements in the given type, recursing through
+/// nested `VectorType`, `LLVMArrayType` and `LLVMFixedVectorType` types. Any
+/// other type counts as 1; `LLVMScalableVectorType` is rejected by an assert.
+static int64_t getNumElements(Type t) {
+  if (auto vecType = dyn_cast<VectorType>(t))
+    return vecType.getNumElements() * getNumElements(vecType.getElementType());
+  if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
+    return arrayType.getNumElements() *
+           getNumElements(arrayType.getElementType());
+  if (auto vecType = dyn_cast<LLVMFixedVectorType>(t))
+    return vecType.getNumElements() * getNumElements(vecType.getElementType());
+  assert(!isa<LLVM::LLVMScalableVectorType>(t) &&
+         "number of elements of a scalable vector type is unknown");
+  return 1;
+}
+
+/// Return true if `t` is an `LLVMScalableVectorType` or a scalable builtin
+/// `VectorType`, or a vector/array type whose element type (nested) is one.
+static bool hasScalableVectorType(Type t) {
+  if (isa<LLVM::LLVMScalableVectorType>(t))
+    return true;
+  if (auto vecType = dyn_cast<VectorType>(t)) {
+    if (vecType.isScalable())
+      return true;
+    return hasScalableVectorType(vecType.getElementType());
+  }
+  if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
+    return hasScalableVectorType(arrayType.getElementType());
+  if (auto vecType = dyn_cast<LLVMFixedVectorType>(t))
+    return hasScalableVectorType(vecType.getElementType());
+  return false;
+}
+
 LogicalResult LLVM::ConstantOp::verify() {
   if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue())) {
     auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType());
@@ -2708,14 +2741,12 @@ LogicalResult LLVM::ConstantOp::verify() {
   if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
     return emitOpError() << "does not support target extension type.";
   }
-  if (!llvm::isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>(getValue()))
-    return emitOpError()
-           << "only supports integer, float, string or elements attributes";
+
+  // Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
   if (auto intAttr = dyn_cast<IntegerAttr>(getValue())) {
     if (!llvm::isa<IntegerType>(getType()))
       return emitOpError() << "expected integer type";
-  }
-  if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
+  } else if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
     const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
     unsigned floatWidth = APFloat::getSizeInBits(sem);
     if (auto floatTy = dyn_cast<FloatType>(getType())) {
@@ -2728,13 +2759,34 @@ LogicalResult LLVM::ConstantOp::verify() {
     if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) {
       return emitOpError() << "expected integer type of width " << floatWidth;
     }
-  }
-  if (auto splatAttr = dyn_cast<SplatElementsAttr>(getValue())) {
-    if (!isa<VectorType>(getType()) && !isa<LLVM::LLVMArrayType>(getType()) &&
-        !isa<LLVM::LLVMFixedVectorType>(getType()) &&
-        !isa<LLVM::LLVMScalableVectorType>(getType()))
+  } else if (isa<ElementsAttr, ArrayAttr>(getValue())) {
+    if (hasScalableVectorType(getType())) {
+      // The exact number of elements of a scalable vector is unknown, so we
+      // allow only splat attributes.
+      auto splatElementsAttr = dyn_cast<SplatElementsAttr>(getValue());
+      if (!splatElementsAttr)
+        return emitOpError()
+               << "scalable vector type requires a splat attribute";
+      return success();
+    }
+    if (!isa<VectorType, LLVM::LLVMArrayType, LLVM::LLVMFixedVectorType>(
+            getType()))
       return emitOpError() << "expected vector or array type";
+    // The number of elements of the attribute and the type must match.
+    int64_t attrNumElements;
+    if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue()))
+      attrNumElements = elementsAttr.getNumElements();
+    else
+      attrNumElements = cast<ArrayAttr>(getValue()).size();
+    if (getNumElements(getType()) != attrNumElements)
+      return emitOpError()
+             << "type and attribute have a different number of elements: "
+             << getNumElements(getType()) << " vs. " << attrNumElements;
+  } else {
+    return emitOpError()
+           << "only supports integer, float, string or elements attributes";
   }
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index fe288dab973f5a..62346ce0d2c4b1 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -414,6 +414,22 @@ llvm.func @struct_wrong_element_types() -> !llvm.struct<(!llvm.array<2 x f64>, !
 
 // -----
 
+llvm.func @const_wrong_number_of_elements() -> vector<5xf64> {
+  // expected-error @+1{{type and attribute have a different number of elements: 5 vs. 2}}
+  %0 = llvm.mlir.constant(dense<[1.0, 1.0]> : tensor<2xf64>) : vector<5xf64>
+  llvm.return %0 : vector<5xf64>
+}
+
+// -----
+
+llvm.func @scalable_vec_requires_splat() -> vector<[4]xf64> {
+  // expected-error @+1{{scalable vector type requires a splat attribute}}
+  %0 = llvm.mlir.constant(dense<[1.0, 1.0, 2.0, 2.0]> : tensor<4xf64>) : vector<[4]xf64>
+  llvm.return %0 : vector<[4]xf64>
+}
+
+// -----
+
 func.func @insertvalue_non_llvm_type(%a : i32, %b : i32) {
   // expected-error@+2 {{expected LLVM IR Dialect type}}
   llvm.insertvalue %a, %b[0] : tensor<*xi32>
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index fbdf725f3ec17b..8453983aa07c33 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1295,11 +1295,17 @@ llvm.func @complexintconstant() -> !llvm.struct<(i32, i32)> {
 }
 
 llvm.func @complexintconstantsplat() -> !llvm.array<2 x !llvm.struct<(i32, i32)>> {
-  %1 = llvm.mlir.constant(dense<(0, 1)> : tensor<complex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
+  %1 = llvm.mlir.constant(dense<(0, 1)> : tensor<2xcomplex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
   // CHECK: ret [2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 0, i32 1 }]
   llvm.return %1 : !llvm.array<2 x !llvm.struct<(i32, i32)>>
 }
 
+llvm.func @complexintconstantsingle() -> !llvm.array<1 x !llvm.struct<(i32, i32)>> {
+  %1 = llvm.mlir.constant(dense<(0, 1)> : tensor<complex<i32>>) : !llvm.array<1 x !llvm.struct<(i32, i32)>>
+  // CHECK: ret [1 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }]
+  llvm.return %1 : !llvm.array<1 x !llvm.struct<(i32, i32)>>
+}
+
 llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>> {
   %1 = llvm.mlir.constant(dense<[[(0, 1), (2, 3)], [(4, 5), (6, 7)]]> : tensor<2x2xcomplex<i32>>) : !llvm.array<2 x!llvm.array<2 x !llvm.struct<(i32, i32)>>>
   // CHECK{LITERAL}: ret [2 x [2 x { i32, i32 }]] [[2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 2, i32 3 }], [2 x { i32, i32 }] [{ i32, i32 } { i32 4, i32 5 }, { i32, i32 } { i32 6, i32 7 }]]



More information about the Mlir-commits mailing list