[Mlir-commits] [mlir] 6c072c0 - [mlir][spirv] Fix verification and serialization replicated constant … (#151168)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 1 10:02:22 PDT 2025


Author: Mohammadreza Ameri Mahabadian
Date: 2025-08-01T13:02:19-04:00
New Revision: 6c072c06cc59b56f83014450afe82f03666407ad

URL: https://github.com/llvm/llvm-project/commit/6c072c06cc59b56f83014450afe82f03666407ad
DIFF: https://github.com/llvm/llvm-project/commit/6c072c06cc59b56f83014450afe82f03666407ad.diff

LOG: [mlir][spirv] Fix verification and serialization replicated constant … (#151168)

…composites of multi-dimensional array

This fixes a bug in verification and serialization of replicated
constant composite ops where the splat value can potentially be a
multi-dimensional array.

---------

Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
    mlir/test/Target/SPIRV/constant.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 52c672a05fa43..f99339852824c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -767,19 +767,25 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
 // spirv.EXTConstantCompositeReplicate
 //===----------------------------------------------------------------------===//
 
+// Returns the type of `attr`, or a null Type when it cannot be derived. A
+// TypedAttr reports its own type. An ArrayAttr is untyped and may be nested,
+// so its spirv::ArrayType is built recursively from the first element.
+static Type getValueType(Attribute attr) {
+  if (auto typedAttr = dyn_cast<TypedAttr>(attr))
+    return typedAttr.getType();
+  // Empty arrays and untyped leaf elements have no derivable type.
+  auto arrayAttr = dyn_cast<ArrayAttr>(attr);
+  if (!arrayAttr || arrayAttr.empty())
+    return nullptr;
+  if (Type elementType = getValueType(arrayAttr[0]))
+    return spirv::ArrayType::get(elementType, arrayAttr.size());
+  return nullptr;
+}
+
 LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
-  Type valueType;
-  if (auto typedAttr = dyn_cast<TypedAttr>(getValue())) {
-    valueType = typedAttr.getType();
-  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {
-    auto typedElemAttr = dyn_cast<TypedAttr>(arrayAttr[0]);
-    if (!typedElemAttr)
-      return emitError("value attribute is not typed");
-    valueType =
-        spirv::ArrayType::get(typedElemAttr.getType(), arrayAttr.size());
-  } else {
+  Type valueType = getValueType(getValue());
+  if (!valueType)
     return emitError("unknown value attribute type");
-  }
 
   auto compositeType = dyn_cast<spirv::CompositeType>(getType());
   if (!compositeType)

diff  --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 59665ec1add54..30536638b56f7 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1187,6 +1187,21 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
   return resultID;
 }
 
+// Returns the type of `attr`, or a null Type when it cannot be derived. A
+// TypedAttr reports its own type. An ArrayAttr is untyped and may be nested,
+// so its spirv::ArrayType is built recursively from the first element.
+static Type getValueType(Attribute attr) {
+  if (auto typedAttr = dyn_cast<TypedAttr>(attr))
+    return typedAttr.getType();
+  // Empty arrays and untyped leaf elements have no derivable type.
+  auto arrayAttr = dyn_cast<ArrayAttr>(attr);
+  if (!arrayAttr || arrayAttr.empty())
+    return nullptr;
+  if (Type elementType = getValueType(arrayAttr[0]))
+    return spirv::ArrayType::get(elementType, arrayAttr.size());
+  return nullptr;
+}
+
 uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
                                                        Type resultType,
                                                        Attribute valueAttr) {
@@ -1200,18 +1215,9 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
     return 0;
   }
 
-  Type valueType;
-  if (auto typedAttr = dyn_cast<TypedAttr>(valueAttr)) {
-    valueType = typedAttr.getType();
-  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
-    auto typedElemAttr = dyn_cast<TypedAttr>(arrayAttr[0]);
-    if (!typedElemAttr)
-      return 0;
-    valueType =
-        spirv::ArrayType::get(typedElemAttr.getType(), arrayAttr.size());
-  } else {
+  Type valueType = getValueType(valueAttr);
+  if (!valueType) // getValueType returns null for unsupported attributes.
     return 0;
-  }
 
   auto compositeType = dyn_cast<CompositeType>(resultType);
   if (!compositeType)

diff  --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 3be49eefcaebf..c81ceac072bd0 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -405,6 +405,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
   }
 
+  // CHECK-LABEL: @splat_array_of_non_splat_array_of_arrays_of_i32
+  spirv.func @splat_array_of_non_splat_array_of_arrays_of_i32() -> !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}{{\[}}[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
+    %0 = spirv.EXT.ConstantCompositeReplicate [[[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
+  }
+
   // CHECK-LABEL: @null_cc_arm_tensor_of_i32
   spirv.func @null_cc_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
     // CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
@@ -461,6 +468,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
   }
 
+  // CHECK-LABEL: @splat_array_of_non_splat_array_of_arrays_of_f32
+  spirv.func @splat_array_of_non_splat_array_of_arrays_of_f32() -> !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}{{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32], [4.000000e+00 : f32, 5.000000e+00 : f32, 6.000000e+00 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
+    %0 = spirv.EXT.ConstantCompositeReplicate [[[1.0 : f32, 2.0 : f32, 3.0 : f32], [4.0 : f32, 5.0 : f32, 6.0 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
+  }
+
   // CHECK-LABEL: @null_cc_arm_tensor_of_f32
   spirv.func @null_cc_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
     // CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32>


        


More information about the Mlir-commits mailing list