[Mlir-commits] [mlir] [mlir][spirv] Add conversion pass to rewrite splat constant composite… (PR #148910)

Mohammadreza Ameri Mahabadian llvmlistbot at llvm.org
Fri Jul 18 06:28:04 PDT 2025


https://github.com/mahabadm updated https://github.com/llvm/llvm-project/pull/148910

>From 936435e95e1cab2e888eb8483bd3aeaf3f91857b Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Wed, 2 Jul 2025 09:00:45 +0100
Subject: [PATCH 1/9] [mlir][spirv] Add conversion pass to rewrite splat
 constant composites to replicated form

This adds a new SPIR-V dialect-level conversion pass `ConversionToReplicatedConstantCompositePass`.
This pass looks for splat composite `spirv.Constant` and `spirv.SpecConstantComposite` ops (composites whose constituents are all identical) and rewrites them into `spirv.EXT.ConstantCompositeReplicate` and `spirv.EXT.SpecConstantCompositeReplicate`, respectively.

Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
---
 .../mlir/Dialect/SPIRV/Transforms/Passes.td   |   7 +
 .../Dialect/SPIRV/Transforms/CMakeLists.txt   |   2 +
 ...rsionToReplicatedConstantCompositePass.cpp | 135 ++++++++++++
 .../replicated-const-composites.mlir          | 192 ++++++++++++++++++
 4 files changed, 336 insertions(+)
 create mode 100644 mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
 create mode 100644 mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir

diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
index 2d9befe78001d..3c04db8396367 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
@@ -77,4 +77,11 @@ def SPIRVWebGPUPreparePass : Pass<"spirv-webgpu-prepare", "spirv::ModuleOp"> {
                 "and replacing with supported ones";
 }
 
+def SPIRVReplicatedConstantCompositePass
+    : Pass<"spirv-replicated-const-composite", "spirv::ModuleOp"> {
+  let summary = "Convert splat composite constants and spec constants to"
+                "corresponding replicated constant composite ops defined by"
+                "SPV_EXT_replicated_composites";
+}
+
 #endif // MLIR_DIALECT_SPIRV_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
index 68e0206e30a59..c675af9d048cc 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 set(LLVM_OPTIONAL_SOURCES
   CanonicalizeGLPass.cpp
+  ConversionToReplicatedConstantCompositePass.cpp
   DecorateCompositeTypeLayoutPass.cpp
   LowerABIAttributesPass.cpp
   RewriteInsertsPass.cpp
@@ -30,6 +31,7 @@ add_mlir_dialect_library(MLIRSPIRVConversion
 
 add_mlir_dialect_library(MLIRSPIRVTransforms
   CanonicalizeGLPass.cpp
+  ConversionToReplicatedConstantCompositePass.cpp
   DecorateCompositeTypeLayoutPass.cpp
   LowerABIAttributesPass.cpp
   RewriteInsertsPass.cpp
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
new file mode 100644
index 0000000000000..530a0f4aa67f5
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
@@ -0,0 +1,135 @@
+//===- ConversionToReplicatedConstantCompositePass.cpp
+//---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert a splat composite spirv.Constant and
+// spirv.SpecConstantComposite to spirv.EXT.ConstantCompositeReplicate and
+// spirv.EXT.SpecConstantCompositeReplicate respectively.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace spirv {
+#define GEN_PASS_DEF_SPIRVREPLICATEDCONSTANTCOMPOSITEPASS
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
+} // namespace spirv
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+Attribute getSplatAttribute(Attribute valueAttr, uint32_t splatCount) {
+  Attribute attr;
+  if (auto denseAttr = dyn_cast<DenseElementsAttr>(valueAttr)) {
+    if (denseAttr.isSplat()) {
+      attr = denseAttr.getSplatValue<Attribute>();
+      splatCount = denseAttr.size();
+    }
+  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
+    if (std::adjacent_find(arrayAttr.begin(), arrayAttr.end(),
+                           std::not_equal_to<>()) == arrayAttr.end()) {
+      attr = arrayAttr[0];
+      splatCount = arrayAttr.size();
+    }
+  }
+
+  if (attr) {
+    if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
+      if (isa<spirv::CompositeType>(typedAttr.getType()))
+        if (Attribute newAttr = getSplatAttribute(attr, splatCount))
+          attr = newAttr;
+    } else if (isa<ArrayAttr>(attr)) {
+      if (Attribute newAttr = getSplatAttribute(attr, splatCount))
+        attr = newAttr;
+    }
+  }
+
+  return attr;
+}
+
+} // namespace
+
+namespace {
+class ConversionToReplicatedConstantCompositePass
+    : public spirv::impl::SPIRVReplicatedConstantCompositePassBase<
+          ConversionToReplicatedConstantCompositePass> {
+public:
+  void runOnOperation() override;
+};
+
+class ConstantOpConversion : public OpRewritePattern<spirv::ConstantOp> {
+  using OpRewritePattern<spirv::ConstantOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::ConstantOp op,
+                                PatternRewriter &rewriter) const override {
+    auto compositeType = dyn_cast_or_null<spirv::CompositeType>(op.getType());
+    if (!compositeType)
+      return rewriter.notifyMatchFailure(op, "not a composite constant");
+
+    uint32_t splatCount = 0;
+    Attribute splatAttr = getSplatAttribute(op.getValue(), splatCount);
+    if (!splatAttr)
+      return rewriter.notifyMatchFailure(op, "composite is not splat");
+
+    if (splatCount == 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "composite has only one consituent");
+
+    rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
+        op, op.getType(), splatAttr);
+
+    return success();
+  }
+};
+
+class SpecConstantCompositeOpConversion
+    : public OpRewritePattern<spirv::SpecConstantCompositeOp> {
+  using OpRewritePattern<spirv::SpecConstantCompositeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::SpecConstantCompositeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto compositeType = dyn_cast_or_null<spirv::CompositeType>(op.getType());
+    if (!compositeType)
+      return rewriter.notifyMatchFailure(op, "not a composite constant");
+
+    auto constituents = op.getConstituents();
+    if (constituents.size() == 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "composite has only one consituent");
+
+    if (!(std::adjacent_find(constituents.begin(), constituents.end(),
+                             std::not_equal_to<>()) == constituents.end()))
+      return rewriter.notifyMatchFailure(op, "composite is not splat");
+
+    auto splatConstituent =
+        dyn_cast<FlatSymbolRefAttr>(op.getConstituents()[0]);
+
+    rewriter.replaceOpWithNewOp<spirv::EXTSpecConstantCompositeReplicateOp>(
+        op, TypeAttr::get(op.getType()), op.getSymNameAttr(), splatConstituent);
+
+    return success();
+  }
+};
+
+void ConversionToReplicatedConstantCompositePass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  RewritePatternSet patterns(context);
+  patterns.add<ConstantOpConversion>(context);
+  patterns.add<SpecConstantCompositeOpConversion>(context);
+
+  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+    signalPassFailure();
+  }
+}
+
+} // namespace
\ No newline at end of file
diff --git a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
new file mode 100644
index 0000000000000..f8cd4bb256bfe
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
@@ -0,0 +1,192 @@
+// RUN: mlir-opt -spirv-replicated-const-composite -split-input-file -verify-diagnostics %s -o - | FileCheck %s
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_vector_of_i32() -> (vector<3xi32>) "None" {
+    %0 = spirv.Constant dense<2> : vector<3xi32>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : vector<3xi32>
+    spirv.ReturnValue %0 : vector<3xi32>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_array_of_i32() -> (!spirv.array<3 x i32>) "None" {
+    %0 = spirv.Constant [1 : i32, 1 : i32, 1 : i32] : !spirv.array<3 x i32>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32>
+    spirv.ReturnValue %0 : !spirv.array<3 x i32>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_array_of_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
+    %0 = spirv.Constant [[3 : i32, 3 : i32, 3 : i32], [3 : i32, 3 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
+    %0 = spirv.Constant [[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" {
+    %0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_array_of_splat_vectors_of_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" {
+    %0 = spirv.Constant [dense<2> : vector<2xi32>, dense<2> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_tensor_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
+    %0 = spirv.Constant dense<3> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.arm.tensor<2x3xi32>
+    %0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_vector_of_f32() -> (vector<3xf32>) "None" {
+    %0 = spirv.Constant dense<2.0> : vector<3xf32>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : vector<3xf32>
+    spirv.ReturnValue %0 : vector<3xf32>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_array_of_f32() -> (!spirv.array<3 x f32>) "None" {
+    %0 = spirv.Constant [1.0 : f32, 1.0 : f32, 1.0 : f32] : !spirv.array<3 x f32>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<3 x f32>
+    spirv.ReturnValue %0 : !spirv.array<3 x f32>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_array_of_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
+    %0 = spirv.Constant [[3.0 : f32, 3.0 : f32, 3.0 : f32], [3.0 : f32, 3.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_array_of_non_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
+    %0 = spirv.Constant [[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" {
+    %0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_array_of_splat_vectors_of_f32() -> (!spirv.array<2 x vector<2xf32>>) "None" {
+    %0 = spirv.Constant [dense<2.0> : vector<2xf32>, dense<2.0> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.array<2 x vector<2xf32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_tensor_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
+    %0 = spirv.Constant dense<3.0> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.arm.tensor<2x3xf32>
+    %0 = spirv.Constant dense<2.0> : !spirv.arm.tensor<2x3xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompositesEXT], [SPV_EXT_replicated_composites]> {
+
+  spirv.SpecConstant @sc_i32_1 = 1 : i32
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_i32 (@sc_i32_1) : !spirv.array<3 x i32>
+  spirv.SpecConstantComposite @scc_splat_array_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.array<3 x i32>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_i32 (@sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+  spirv.SpecConstantComposite @scc_splat_struct_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_i32 (@sc_i32_1) : vector<3xi32>
+  spirv.SpecConstantComposite @scc_splat_vector_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : vector<3 x i32>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_arm_tensor_of_i32 (@sc_i32_1) : !spirv.arm.tensor<3xi32>
+  spirv.SpecConstantComposite @scc_splat_arm_tensor_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.arm.tensor<3xi32>
+
+  spirv.SpecConstant @sc_f32_1 = 1.0 : f32
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_f32 (@sc_f32_1) : !spirv.array<3 x f32>
+  spirv.SpecConstantComposite @scc_splat_array_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.array<3 x f32>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_f32 (@sc_f32_1) : !spirv.struct<(f32, f32, f32)>
+  spirv.SpecConstantComposite @scc_splat_struct_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.struct<(f32, f32, f32)>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_f32 (@sc_f32_1) : vector<3xf32>
+  spirv.SpecConstantComposite @scc_splat_vector_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : vector<3 x f32>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_arm_tensor_of_f32 (@sc_f32_1) : !spirv.arm.tensor<3xf32>
+  spirv.SpecConstantComposite @scc_splat_arm_tensor_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.arm.tensor<3xf32>
+}
\ No newline at end of file

>From 14b70ce5742cacc84b8d2833e62eface967c07c8 Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Wed, 16 Jul 2025 08:20:03 +0100
Subject: [PATCH 2/9] Pass splatCount by reference and drop the vce requirement from the spec constant test

Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
---
 .../Transforms/ConversionToReplicatedConstantCompositePass.cpp  | 2 +-
 .../Dialect/SPIRV/Transforms/replicated-const-composites.mlir   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
index 530a0f4aa67f5..590fa6e9d684a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
@@ -28,7 +28,7 @@ using namespace mlir;
 
 namespace {
 
-Attribute getSplatAttribute(Attribute valueAttr, uint32_t splatCount) {
+Attribute getSplatAttribute(Attribute valueAttr, uint32_t &splatCount) {
   Attribute attr;
   if (auto denseAttr = dyn_cast<DenseElementsAttr>(valueAttr)) {
     if (denseAttr.isSplat()) {
diff --git a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
index f8cd4bb256bfe..a7a9ca25edc6a 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
@@ -160,7 +160,7 @@ spirv.module Logical GLSL450 {
 
 // -----
 
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompositesEXT], [SPV_EXT_replicated_composites]> {
+spirv.module Logical GLSL450 {
 
   spirv.SpecConstant @sc_i32_1 = 1 : i32
 

>From 51ce7ac82a887a683010b54c31c95ec7d5e9bd54 Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Wed, 16 Jul 2025 15:31:21 +0100
Subject: [PATCH 3/9] Rename pass flag, fix typos, check splat constituent kind, and add negative tests

Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
---
 .../mlir/Dialect/SPIRV/Transforms/Passes.td   |   2 +-
 ...rsionToReplicatedConstantCompositePass.cpp |  10 +-
 .../replicated-const-composites.mlir          | 145 ++++++++++++++++--
 3 files changed, 135 insertions(+), 22 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
index 3c04db8396367..bc1c1f075e09b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
@@ -78,7 +78,7 @@ def SPIRVWebGPUPreparePass : Pass<"spirv-webgpu-prepare", "spirv::ModuleOp"> {
 }
 
 def SPIRVReplicatedConstantCompositePass
-    : Pass<"spirv-replicated-const-composite", "spirv::ModuleOp"> {
+    : Pass<"spirv-convert-to-replicated-const-composite", "spirv::ModuleOp"> {
   let summary = "Convert splat composite constants and spec constants to"
                 "corresponding replicated constant composite ops defined by"
                 "SPV_EXT_replicated_composites";
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
index 590fa6e9d684a..da1777ab42f97 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
@@ -1,5 +1,4 @@
-//===- ConversionToReplicatedConstantCompositePass.cpp
-//---------------------------===//
+//===- ConversionToReplicatedConstantCompositePass.cpp --------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -83,7 +82,7 @@ class ConstantOpConversion : public OpRewritePattern<spirv::ConstantOp> {
 
     if (splatCount == 1)
       return rewriter.notifyMatchFailure(op,
-                                         "composite has only one consituent");
+                                         "composite has only one constituent");
 
     rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
         op, op.getType(), splatAttr);
@@ -102,7 +101,7 @@ class SpecConstantCompositeOpConversion
     if (!compositeType)
       return rewriter.notifyMatchFailure(op, "not a composite constant");
 
-    auto constituents = op.getConstituents();
+    ArrayAttr constituents = op.getConstituents();
     if (constituents.size() == 1)
       return rewriter.notifyMatchFailure(op,
                                          "composite has only one consituent");
@@ -113,6 +112,9 @@ class SpecConstantCompositeOpConversion
 
     auto splatConstituent =
         dyn_cast<FlatSymbolRefAttr>(op.getConstituents()[0]);
+    if (!splatConstituent)
+      return rewriter.notifyMatchFailure(
+          op, "expected flat symbol reference for splat constituent");
 
     rewriter.replaceOpWithNewOp<spirv::EXTSpecConstantCompositeReplicateOp>(
         op, TypeAttr::get(op.getType()), op.getSymNameAttr(), splatConstituent);
diff --git a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
index a7a9ca25edc6a..4431f417635b8 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt -spirv-replicated-const-composite -split-input-file -verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-opt --spirv-convert-to-replicated-const-composite --split-input-file --verify-diagnostics %s -o - | FileCheck %s
 
 spirv.module Logical GLSL450 {
   spirv.func @splat_vector_of_i32() -> (vector<3xi32>) "None" {
-    %0 = spirv.Constant dense<2> : vector<3xi32>
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : vector<3xi32>
+    %0 = spirv.Constant dense<2> : vector<3xi32>
     spirv.ReturnValue %0 : vector<3xi32>
   }
 }
@@ -12,8 +12,8 @@ spirv.module Logical GLSL450 {
 
 spirv.module Logical GLSL450 {
   spirv.func @splat_array_of_i32() -> (!spirv.array<3 x i32>) "None" {
-    %0 = spirv.Constant [1 : i32, 1 : i32, 1 : i32] : !spirv.array<3 x i32>
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32>
+    %0 = spirv.Constant [1 : i32, 1 : i32, 1 : i32] : !spirv.array<3 x i32>
     spirv.ReturnValue %0 : !spirv.array<3 x i32>
   }
 }
@@ -22,7 +22,7 @@ spirv.module Logical GLSL450 {
 
 spirv.module Logical GLSL450 {
   spirv.func @splat_array_of_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
-    // CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
     %0 = spirv.Constant [[3 : i32, 3 : i32, 3 : i32], [3 : i32, 3 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
     spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
   }
@@ -32,7 +32,7 @@ spirv.module Logical GLSL450 {
 
 spirv.module Logical GLSL450 {
   spirv.func @splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
-    // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
     %0 = spirv.Constant [[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
     spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
   }
@@ -42,8 +42,8 @@ spirv.module Logical GLSL450 {
 
 spirv.module Logical GLSL450 {
   spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" {
-    %0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
+    %0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
     spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
   }
 }
@@ -52,8 +52,8 @@ spirv.module Logical GLSL450 {
 
 spirv.module Logical GLSL450 {
   spirv.func @splat_array_of_splat_vectors_of_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" {
-    %0 = spirv.Constant [dense<2> : vector<2xi32>, dense<2> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>>
+    %0 = spirv.Constant [dense<2> : vector<2xi32>, dense<2> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
     spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
   }
 }
@@ -62,7 +62,7 @@ spirv.module Logical GLSL450 {
 
 spirv.module Logical GLSL450 {
   spirv.func @splat_tensor_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
-    // CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
     %0 = spirv.Constant dense<3> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
     spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
   }
@@ -72,7 +72,7 @@ spirv.module Logical GLSL450 {
 
 spirv.module Logical GLSL450 {
   spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
-    // CHECK: spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.arm.tensor<2x3xi32>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.arm.tensor<2x3xi32>
     %0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
   }
@@ -82,8 +82,8 @@ spirv.module Logical GLSL450 {
 
 spirv.module Logical GLSL450 {
   spirv.func @splat_vector_of_f32() -> (vector<3xf32>) "None" {
-    %0 = spirv.Constant dense<2.0> : vector<3xf32>
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : vector<3xf32>
+    %0 = spirv.Constant dense<2.0> : vector<3xf32>
     spirv.ReturnValue %0 : vector<3xf32>
   }
 }
@@ -92,8 +92,8 @@ spirv.module Logical GLSL450 {
 
 spirv.module Logical GLSL450 {
   spirv.func @splat_array_of_f32() -> (!spirv.array<3 x f32>) "None" {
-    %0 = spirv.Constant [1.0 : f32, 1.0 : f32, 1.0 : f32] : !spirv.array<3 x f32>
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<3 x f32>
+    %0 = spirv.Constant [1.0 : f32, 1.0 : f32, 1.0 : f32] : !spirv.array<3 x f32>
     spirv.ReturnValue %0 : !spirv.array<3 x f32>
   }
 }
@@ -102,7 +102,7 @@ spirv.module Logical GLSL450 {
 
 spirv.module Logical GLSL450 {
   spirv.func @splat_array_of_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
-    // CHECK: spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
     %0 = spirv.Constant [[3.0 : f32, 3.0 : f32, 3.0 : f32], [3.0 : f32, 3.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
     spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
   }
@@ -112,7 +112,7 @@ spirv.module Logical GLSL450 {
 
 spirv.module Logical GLSL450 {
   spirv.func @splat_array_of_non_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
-    // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
     %0 = spirv.Constant [[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
     spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
   }
@@ -122,8 +122,8 @@ spirv.module Logical GLSL450 {
 
 spirv.module Logical GLSL450 {
   spirv.func @splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" {
-    %0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
+    %0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
     spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
   }
 }
@@ -132,8 +132,8 @@ spirv.module Logical GLSL450 {
 
 spirv.module Logical GLSL450 {
   spirv.func @splat_array_of_splat_vectors_of_f32() -> (!spirv.array<2 x vector<2xf32>>) "None" {
-    %0 = spirv.Constant [dense<2.0> : vector<2xf32>, dense<2.0> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.array<2 x vector<2xf32>>
+    %0 = spirv.Constant [dense<2.0> : vector<2xf32>, dense<2.0> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
     spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
   }
 }
@@ -142,7 +142,7 @@ spirv.module Logical GLSL450 {
 
 spirv.module Logical GLSL450 {
   spirv.func @splat_tensor_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
-    // CHECK: spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
     %0 = spirv.Constant dense<3.0> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
     spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
   }
@@ -152,7 +152,7 @@ spirv.module Logical GLSL450 {
 
 spirv.module Logical GLSL450 {
   spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
-    // CHECK: spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.arm.tensor<2x3xf32>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.arm.tensor<2x3xf32>
     %0 = spirv.Constant dense<2.0> : !spirv.arm.tensor<2x3xf32>
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
   }
@@ -160,6 +160,86 @@ spirv.module Logical GLSL450 {
 
 // -----
 
+spirv.module Logical GLSL450 {
+  spirv.func @array_of_one_i32() -> (!spirv.array<1 x i32>) "None" {
+    // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
+    %0 = spirv.Constant [1 : i32] : !spirv.array<1 x i32>
+    spirv.ReturnValue %0 : !spirv.array<1 x i32>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @arm_tensor_of_one_i32() -> (!spirv.arm.tensor<1xi32>) "None" {
+    // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
+    %0 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<1xi32>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @non_splat_vector_of_i32() -> (vector<3xi32>) "None" {
+    // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
+    %0 = spirv.Constant dense<[0, 1, 2]> : vector<3xi32>
+    spirv.ReturnValue %0 : vector<3xi32>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @non_splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" {
+    // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
+    %0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 3]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @array_of_one_f32() -> (!spirv.array<1 x f32>) "None" {
+    // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
+    %0 = spirv.Constant [1.0 : f32] : !spirv.array<1 x f32>
+    spirv.ReturnValue %0 : !spirv.array<1 x f32>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @arm_tensor_of_one_f32() -> (!spirv.arm.tensor<1xf32>) "None" {
+    // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
+    %0 = spirv.Constant dense<1.0> : !spirv.arm.tensor<1xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<1xf32>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @non_splat_vector_of_f32() -> (vector<3xf32>) "None" {
+    // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
+    %0 = spirv.Constant dense<[0.0, 1.0, 2.0]> : vector<3xf32>
+    spirv.ReturnValue %0 : vector<3xf32>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @non_splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" {
+    // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
+    %0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 3.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
+  }
+}
+
+// -----
+
 spirv.module Logical GLSL450 {
 
   spirv.SpecConstant @sc_i32_1 = 1 : i32
@@ -189,4 +269,35 @@ spirv.module Logical GLSL450 {
 
   // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_arm_tensor_of_f32 (@sc_f32_1) : !spirv.arm.tensor<3xf32>
   spirv.SpecConstantComposite @scc_splat_arm_tensor_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.arm.tensor<3xf32>
+
+  spirv.SpecConstant @sc_i32_2 = 2 : i32
+
+  // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
+  spirv.SpecConstantComposite @scc_array_of_one_i32 (@sc_i32_1) : !spirv.array<1 x i32>
+
+  // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
+  spirv.SpecConstantComposite @scc_arm_tensor_of_one_i32 (@sc_i32_1) : !spirv.arm.tensor<1xi32>
+
+  // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
+  spirv.SpecConstantComposite @scc_non_splat_vector_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_2) : vector<3 x i32>
+
+  // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
+  spirv.SpecConstantComposite @scc_non_splat_arm_tensor_of_i32 (@sc_i32_2, @sc_i32_1, @sc_i32_1) : !spirv.arm.tensor<3xi32>
+
+  spirv.SpecConstant @sc_f32_2 = 2.0 : f32
+
+  // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
+  spirv.SpecConstantComposite @scc_array_of_one_f32 (@sc_f32_1) : !spirv.array<1 x f32>
+
+  // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
+  spirv.SpecConstantComposite @scc_arm_tensor_of_one_f32 (@sc_f32_1) : !spirv.arm.tensor<1xf32>
+
+  // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
+  spirv.SpecConstantComposite @scc_non_splat_vector_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_2) : vector<3 x f32>
+
+  // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
+  spirv.SpecConstantComposite @scc_non_splat_arm_tensor_of_f32 (@sc_f32_2, @sc_f32_1, @sc_f32_1) : !spirv.arm.tensor<3xf32>
+
+  // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
+  spirv.SpecConstantComposite @scc_struct_of_i32_and_f32 (@sc_i32_1, @sc_i32_1, @sc_f32_1) : !spirv.struct<(i32, i32, f32)>
 }
\ No newline at end of file

>From d065a5adb971cf7eca74d92d155647971f79c47f Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Thu, 17 Jul 2025 09:44:18 +0100
Subject: [PATCH 4/9] Address review: rename pass source, use walk pattern driver

Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
---
 .../mlir/Dialect/SPIRV/Transforms/Passes.td   |  4 +-
 .../Dialect/SPIRV/Transforms/CMakeLists.txt   |  4 +-
 ...vertToReplicatedConstantCompositePass.cpp} | 86 +++++++++----------
 .../replicated-const-composites.mlir          | 86 +------------------
 4 files changed, 46 insertions(+), 134 deletions(-)
 rename mlir/lib/Dialect/SPIRV/Transforms/{ConversionToReplicatedConstantCompositePass.cpp => ConvertToReplicatedConstantCompositePass.cpp} (65%)

diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
index bc1c1f075e09b..a4418085b5ce5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
@@ -79,8 +79,8 @@ def SPIRVWebGPUPreparePass : Pass<"spirv-webgpu-prepare", "spirv::ModuleOp"> {
 
 def SPIRVReplicatedConstantCompositePass
     : Pass<"spirv-convert-to-replicated-const-composite", "spirv::ModuleOp"> {
-  let summary = "Convert splat composite constants and spec constants to"
-                "corresponding replicated constant composite ops defined by"
+  let summary = "Convert splat composite constants and spec constants to "
+                "corresponding replicated constant composite ops defined by "
                 "SPV_EXT_replicated_composites";
 }
 
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
index c675af9d048cc..b947447dad46a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -1,6 +1,6 @@
 set(LLVM_OPTIONAL_SOURCES
   CanonicalizeGLPass.cpp
-  ConversionToReplicatedConstantCompositePass.cpp
+  ConvertToReplicatedConstantCompositePass.cpp
   DecorateCompositeTypeLayoutPass.cpp
   LowerABIAttributesPass.cpp
   RewriteInsertsPass.cpp
@@ -31,7 +31,7 @@ add_mlir_dialect_library(MLIRSPIRVConversion
 
 add_mlir_dialect_library(MLIRSPIRVTransforms
   CanonicalizeGLPass.cpp
-  ConversionToReplicatedConstantCompositePass.cpp
+  ConvertToReplicatedConstantCompositePass.cpp
   DecorateCompositeTypeLayoutPass.cpp
   LowerABIAttributesPass.cpp
   RewriteInsertsPass.cpp
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
similarity index 65%
rename from mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
rename to mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
index da1777ab42f97..acd66002746aa 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
@@ -1,4 +1,4 @@
-//===- ConversionToReplicatedConstantCompositePass.cpp --------------------===//
+//===- ConvertToReplicatedConstantCompositePass.cpp --------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -14,21 +14,18 @@
 
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
 
-namespace mlir {
-namespace spirv {
+namespace mlir::spirv {
 #define GEN_PASS_DEF_SPIRVREPLICATEDCONSTANTCOMPOSITEPASS
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
-} // namespace spirv
-} // namespace mlir
-
-using namespace mlir;
 
 namespace {
 
-Attribute getSplatAttribute(Attribute valueAttr, uint32_t &splatCount) {
+static std::pair<Attribute, uint32_t>
+getSplatAttributeAndCount(Attribute valueAttr) {
   Attribute attr;
+  uint32_t splatCount = 0;
   if (auto denseAttr = dyn_cast<DenseElementsAttr>(valueAttr)) {
     if (denseAttr.isSplat()) {
       attr = denseAttr.getSplatValue<Attribute>();
@@ -44,30 +41,27 @@ Attribute getSplatAttribute(Attribute valueAttr, uint32_t &splatCount) {
 
   if (attr) {
     if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
-      if (isa<spirv::CompositeType>(typedAttr.getType()))
-        if (Attribute newAttr = getSplatAttribute(attr, splatCount))
-          attr = newAttr;
+      if (isa<spirv::CompositeType>(typedAttr.getType())) {
+        std::pair<Attribute, uint32_t> newSplatAttrAndCount =
+            getSplatAttributeAndCount(attr);
+        if (newSplatAttrAndCount.first) {
+          return newSplatAttrAndCount;
+        }
+      }
     } else if (isa<ArrayAttr>(attr)) {
-      if (Attribute newAttr = getSplatAttribute(attr, splatCount))
-        attr = newAttr;
+      std::pair<Attribute, uint32_t> newSplatAttrAndCount =
+          getSplatAttributeAndCount(attr);
+      if (newSplatAttrAndCount.first) {
+        return newSplatAttrAndCount;
+      }
     }
   }
 
-  return attr;
+  return {attr, splatCount};
 }
 
-} // namespace
-
-namespace {
-class ConversionToReplicatedConstantCompositePass
-    : public spirv::impl::SPIRVReplicatedConstantCompositePassBase<
-          ConversionToReplicatedConstantCompositePass> {
-public:
-  void runOnOperation() override;
-};
-
-class ConstantOpConversion : public OpRewritePattern<spirv::ConstantOp> {
-  using OpRewritePattern<spirv::ConstantOp>::OpRewritePattern;
+struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(spirv::ConstantOp op,
                                 PatternRewriter &rewriter) const override {
@@ -75,25 +69,25 @@ class ConstantOpConversion : public OpRewritePattern<spirv::ConstantOp> {
     if (!compositeType)
       return rewriter.notifyMatchFailure(op, "not a composite constant");
 
-    uint32_t splatCount = 0;
-    Attribute splatAttr = getSplatAttribute(op.getValue(), splatCount);
-    if (!splatAttr)
+    std::pair<Attribute, uint32_t> splatAttrAndCount =
+        getSplatAttributeAndCount(op.getValue());
+    if (!splatAttrAndCount.first)
       return rewriter.notifyMatchFailure(op, "composite is not splat");
 
-    if (splatCount == 1)
+    if (splatAttrAndCount.second == 1)
       return rewriter.notifyMatchFailure(op,
                                          "composite has only one constituent");
 
     rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
-        op, op.getType(), splatAttr);
+        op, op.getType(), splatAttrAndCount.first);
 
     return success();
   }
 };
 
-class SpecConstantCompositeOpConversion
-    : public OpRewritePattern<spirv::SpecConstantCompositeOp> {
-  using OpRewritePattern<spirv::SpecConstantCompositeOp>::OpRewritePattern;
+struct SpecConstantCompositeOpConversion final
+    : OpRewritePattern<spirv::SpecConstantCompositeOp> {
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(spirv::SpecConstantCompositeOp op,
                                 PatternRewriter &rewriter) const override {
@@ -123,15 +117,17 @@ class SpecConstantCompositeOpConversion
   }
 };
 
-void ConversionToReplicatedConstantCompositePass::runOnOperation() {
-  MLIRContext *context = &getContext();
-  RewritePatternSet patterns(context);
-  patterns.add<ConstantOpConversion>(context);
-  patterns.add<SpecConstantCompositeOpConversion>(context);
-
-  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
-    signalPassFailure();
+struct ConvertToReplicatedConstantCompositePass final
+    : spirv::impl::SPIRVReplicatedConstantCompositePassBase<
+          ConvertToReplicatedConstantCompositePass> {
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    patterns.add<ConstantOpConversion, SpecConstantCompositeOpConversion>(
+        context);
+    walkAndApplyPatterns(getOperation(), std::move(patterns));
   }
-}
+};
 
-} // namespace
\ No newline at end of file
+} // namespace
+} // namespace mlir::spirv
diff --git a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
index 4431f417635b8..b343d0bc73b4f 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --spirv-convert-to-replicated-const-composite --split-input-file --verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-opt --spirv-convert-to-replicated-const-composite --split-input-file --verify-diagnostics %s | FileCheck %s
 
 spirv.module Logical GLSL450 {
   spirv.func @splat_vector_of_i32() -> (vector<3xi32>) "None" {
@@ -6,211 +6,127 @@ spirv.module Logical GLSL450 {
     %0 = spirv.Constant dense<2> : vector<3xi32>
     spirv.ReturnValue %0 : vector<3xi32>
   }
-}
-
-// -----
 
-spirv.module Logical GLSL450 {
   spirv.func @splat_array_of_i32() -> (!spirv.array<3 x i32>) "None" {
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32>
     %0 = spirv.Constant [1 : i32, 1 : i32, 1 : i32] : !spirv.array<3 x i32>
     spirv.ReturnValue %0 : !spirv.array<3 x i32>
   }
-}
 
-// -----
-
-spirv.module Logical GLSL450 {
   spirv.func @splat_array_of_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
     %0 = spirv.Constant [[3 : i32, 3 : i32, 3 : i32], [3 : i32, 3 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
     spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
   }
-}
-
-// -----
 
-spirv.module Logical GLSL450 {
   spirv.func @splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
     %0 = spirv.Constant [[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
     spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
   }
-}
 
-// -----
-
-spirv.module Logical GLSL450 {
   spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" {
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
     %0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
     spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
   }
-}
-
-// -----
 
-spirv.module Logical GLSL450 {
   spirv.func @splat_array_of_splat_vectors_of_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" {
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>>
     %0 = spirv.Constant [dense<2> : vector<2xi32>, dense<2> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
     spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
   }
-}
 
-// -----
-
-spirv.module Logical GLSL450 {
   spirv.func @splat_tensor_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
     %0 = spirv.Constant dense<3> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
     spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
   }
-}
-
-// -----
 
-spirv.module Logical GLSL450 {
   spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.arm.tensor<2x3xi32>
     %0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
   }
-}
 
-// -----
-
-spirv.module Logical GLSL450 {
   spirv.func @splat_vector_of_f32() -> (vector<3xf32>) "None" {
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : vector<3xf32>
     %0 = spirv.Constant dense<2.0> : vector<3xf32>
     spirv.ReturnValue %0 : vector<3xf32>
   }
-}
-
-// -----
 
-spirv.module Logical GLSL450 {
   spirv.func @splat_array_of_f32() -> (!spirv.array<3 x f32>) "None" {
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<3 x f32>
     %0 = spirv.Constant [1.0 : f32, 1.0 : f32, 1.0 : f32] : !spirv.array<3 x f32>
     spirv.ReturnValue %0 : !spirv.array<3 x f32>
   }
-}
 
-// -----
-
-spirv.module Logical GLSL450 {
   spirv.func @splat_array_of_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
     %0 = spirv.Constant [[3.0 : f32, 3.0 : f32, 3.0 : f32], [3.0 : f32, 3.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
     spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
   }
-}
-
-// -----
 
-spirv.module Logical GLSL450 {
   spirv.func @splat_array_of_non_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
     %0 = spirv.Constant [[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
     spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
   }
-}
 
-// -----
-
-spirv.module Logical GLSL450 {
   spirv.func @splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" {
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
     %0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
     spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
   }
-}
-
-// -----
 
-spirv.module Logical GLSL450 {
   spirv.func @splat_array_of_splat_vectors_of_f32() -> (!spirv.array<2 x vector<2xf32>>) "None" {
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.array<2 x vector<2xf32>>
     %0 = spirv.Constant [dense<2.0> : vector<2xf32>, dense<2.0> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
     spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
   }
-}
 
-// -----
-
-spirv.module Logical GLSL450 {
   spirv.func @splat_tensor_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
     %0 = spirv.Constant dense<3.0> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
     spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
   }
-}
-
-// -----
 
-spirv.module Logical GLSL450 {
   spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.arm.tensor<2x3xf32>
     %0 = spirv.Constant dense<2.0> : !spirv.arm.tensor<2x3xf32>
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
   }
-}
-
-// -----
 
-spirv.module Logical GLSL450 {
   spirv.func @array_of_one_i32() -> (!spirv.array<1 x i32>) "None" {
     // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
     %0 = spirv.Constant [1 : i32] : !spirv.array<1 x i32>
     spirv.ReturnValue %0 : !spirv.array<1 x i32>
   }
-}
-
-// -----
 
-spirv.module Logical GLSL450 {
   spirv.func @arm_tensor_of_one_i32() -> (!spirv.arm.tensor<1xi32>) "None" {
     // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
     %0 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi32>
     spirv.ReturnValue %0 : !spirv.arm.tensor<1xi32>
   }
-}
 
-// -----
-
-spirv.module Logical GLSL450 {
   spirv.func @non_splat_vector_of_i32() -> (vector<3xi32>) "None" {
     // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
     %0 = spirv.Constant dense<[0, 1, 2]> : vector<3xi32>
     spirv.ReturnValue %0 : vector<3xi32>
   }
-}
 
-// -----
-
-spirv.module Logical GLSL450 {
   spirv.func @non_splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" {
     // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
     %0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 3]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
     spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
   }
-}
-
-// -----
 
-spirv.module Logical GLSL450 {
   spirv.func @array_of_one_f32() -> (!spirv.array<1 x f32>) "None" {
     // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
     %0 = spirv.Constant [1.0 : f32] : !spirv.array<1 x f32>
     spirv.ReturnValue %0 : !spirv.array<1 x f32>
   }
-}
 
-// -----
-
-spirv.module Logical GLSL450 {
   spirv.func @arm_tensor_of_one_f32() -> (!spirv.arm.tensor<1xf32>) "None" {
     // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
     %0 = spirv.Constant dense<1.0> : !spirv.arm.tensor<1xf32>

>From b718e34d6a35f83c534121075d1459a5af5bd521 Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Thu, 17 Jul 2025 11:22:13 +0100
Subject: [PATCH 5/9] Address review: simplify splat recursion, use structured bindings

Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
---
 ...nvertToReplicatedConstantCompositePass.cpp | 26 +++++++------------
 1 file changed, 9 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
index acd66002746aa..1856f8017ab33 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
@@ -1,4 +1,4 @@
-//===- ConvertToReplicatedConstantCompositePass.cpp --------------------===//
+//===- ConvertToReplicatedConstantCompositePass.cpp -----------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -40,15 +40,9 @@ getSplatAttributeAndCount(Attribute valueAttr) {
   }
 
   if (attr) {
-    if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
-      if (isa<spirv::CompositeType>(typedAttr.getType())) {
-        std::pair<Attribute, uint32_t> newSplatAttrAndCount =
-            getSplatAttributeAndCount(attr);
-        if (newSplatAttrAndCount.first) {
-          return newSplatAttrAndCount;
-        }
-      }
-    } else if (isa<ArrayAttr>(attr)) {
+    auto typedAttr = dyn_cast<TypedAttr>(attr);
+    if ((typedAttr && isa<spirv::CompositeType>(typedAttr.getType())) ||
+        isa<ArrayAttr>(attr)) {
       std::pair<Attribute, uint32_t> newSplatAttrAndCount =
           getSplatAttributeAndCount(attr);
       if (newSplatAttrAndCount.first) {
@@ -69,17 +63,16 @@ struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
     if (!compositeType)
       return rewriter.notifyMatchFailure(op, "not a composite constant");
 
-    std::pair<Attribute, uint32_t> splatAttrAndCount =
-        getSplatAttributeAndCount(op.getValue());
-    if (!splatAttrAndCount.first)
+    auto [splattAttr, splatCount] = getSplatAttributeAndCount(op.getValue());
+    if (!splattAttr)
       return rewriter.notifyMatchFailure(op, "composite is not splat");
 
-    if (splatAttrAndCount.second == 1)
+    if (splatCount == 1)
       return rewriter.notifyMatchFailure(op,
                                          "composite has only one constituent");
 
     rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
-        op, op.getType(), splatAttrAndCount.first);
+        op, op.getType(), splattAttr);
 
     return success();
   }
@@ -104,8 +97,7 @@ struct SpecConstantCompositeOpConversion final
                              std::not_equal_to<>()) == constituents.end()))
       return rewriter.notifyMatchFailure(op, "composite is not splat");
 
-    auto splatConstituent =
-        dyn_cast<FlatSymbolRefAttr>(op.getConstituents()[0]);
+    auto splatConstituent = dyn_cast<FlatSymbolRefAttr>(constituents[0]);
     if (!splatConstituent)
       return rewriter.notifyMatchFailure(
           op, "expected flat symbol reference for splat constituent");

>From 8234e682f0562f82e63ab02d00db31fc82117d8e Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Thu, 17 Jul 2025 11:30:06 +0100
Subject: [PATCH 6/9] Fix typo: rename splattAttr to splatAttr

Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
---
 .../Transforms/ConvertToReplicatedConstantCompositePass.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
index 1856f8017ab33..798a405b4df69 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
@@ -63,8 +63,8 @@ struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
     if (!compositeType)
       return rewriter.notifyMatchFailure(op, "not a composite constant");
 
-    auto [splattAttr, splatCount] = getSplatAttributeAndCount(op.getValue());
-    if (!splattAttr)
+    auto [splatAttr, splatCount] = getSplatAttributeAndCount(op.getValue());
+    if (!splatAttr)
       return rewriter.notifyMatchFailure(op, "composite is not splat");
 
     if (splatCount == 1)
@@ -72,7 +72,7 @@ struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
                                          "composite has only one constituent");
 
     rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
-        op, op.getType(), splattAttr);
+        op, op.getType(), splatAttr);
 
     return success();
   }

>From 568bccb8140222caa2e701f708fe0c613e805ef8 Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Thu, 17 Jul 2025 16:37:22 +0100
Subject: [PATCH 7/9] Address review: rename pass flag, use llvm::all_equal

Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
---
 .../mlir/Dialect/SPIRV/Transforms/Passes.td   |  2 +-
 ...nvertToReplicatedConstantCompositePass.cpp | 30 +++++++------------
 .../replicated-const-composites.mlir          | 16 +++-------
 3 files changed, 16 insertions(+), 32 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
index a4418085b5ce5..2016bea43fc8a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
@@ -78,7 +78,7 @@ def SPIRVWebGPUPreparePass : Pass<"spirv-webgpu-prepare", "spirv::ModuleOp"> {
 }
 
 def SPIRVReplicatedConstantCompositePass
-    : Pass<"spirv-convert-to-replicated-const-composite", "spirv::ModuleOp"> {
+    : Pass<"spirv-promote-to-replicated-constants", "spirv::ModuleOp"> {
   let summary = "Convert splat composite constants and spec constants to "
                 "corresponding replicated constant composite ops defined by "
                 "SPV_EXT_replicated_composites";
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
index 798a405b4df69..c4df57072a55d 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
@@ -23,28 +23,22 @@ namespace mlir::spirv {
 namespace {
 
 static std::pair<Attribute, uint32_t>
-getSplatAttributeAndCount(Attribute valueAttr) {
+getSplatAttrAndNumElements(Attribute valueAttr) {
   Attribute attr;
   uint32_t splatCount = 0;
-  if (auto denseAttr = dyn_cast<DenseElementsAttr>(valueAttr)) {
-    if (denseAttr.isSplat()) {
-      attr = denseAttr.getSplatValue<Attribute>();
-      splatCount = denseAttr.size();
-    }
-  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
-    if (std::adjacent_find(arrayAttr.begin(), arrayAttr.end(),
-                           std::not_equal_to<>()) == arrayAttr.end()) {
+  if (auto splatAttr = dyn_cast<SplatElementsAttr>(valueAttr)) {
+    return {splatAttr.getSplatValue<Attribute>(), splatAttr.size()};
+  }
+  if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
+    if (llvm::all_equal(arrayAttr)) {
       attr = arrayAttr[0];
       splatCount = arrayAttr.size();
     }
-  }
 
-  if (attr) {
-    auto typedAttr = dyn_cast<TypedAttr>(attr);
-    if ((typedAttr && isa<spirv::CompositeType>(typedAttr.getType())) ||
-        isa<ArrayAttr>(attr)) {
+    if (attr) {
+      // Find the inner-most splat value for array of composites
       std::pair<Attribute, uint32_t> newSplatAttrAndCount =
-          getSplatAttributeAndCount(attr);
+          getSplatAttrAndNumElements(attr);
       if (newSplatAttrAndCount.first) {
         return newSplatAttrAndCount;
       }
@@ -63,7 +57,7 @@ struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
     if (!compositeType)
       return rewriter.notifyMatchFailure(op, "not a composite constant");
 
-    auto [splatAttr, splatCount] = getSplatAttributeAndCount(op.getValue());
+    auto [splatAttr, splatCount] = getSplatAttrAndNumElements(op.getValue());
     if (!splatAttr)
       return rewriter.notifyMatchFailure(op, "composite is not splat");
 
@@ -73,7 +67,6 @@ struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
 
     rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
         op, op.getType(), splatAttr);
-
     return success();
   }
 };
@@ -93,8 +86,7 @@ struct SpecConstantCompositeOpConversion final
       return rewriter.notifyMatchFailure(op,
                                          "composite has only one consituent");
 
-    if (!(std::adjacent_find(constituents.begin(), constituents.end(),
-                             std::not_equal_to<>()) == constituents.end()))
+    if (!(llvm::all_equal(constituents)))
       return rewriter.notifyMatchFailure(op, "composite is not splat");
 
     auto splatConstituent = dyn_cast<FlatSymbolRefAttr>(constituents[0]);
diff --git a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
index b343d0bc73b4f..b3a8bd830c668 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt --spirv-convert-to-replicated-const-composite --split-input-file --verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --spirv-promote-to-replicated-constants --split-input-file %s | FileCheck %s
 
-spirv.module Logical GLSL450 {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompositesEXT], [SPV_EXT_replicated_composites]> {
   spirv.func @splat_vector_of_i32() -> (vector<3xi32>) "None" {
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : vector<3xi32>
     %0 = spirv.Constant dense<2> : vector<3xi32>
@@ -132,21 +132,13 @@ spirv.module Logical GLSL450 {
     %0 = spirv.Constant dense<1.0> : !spirv.arm.tensor<1xf32>
     spirv.ReturnValue %0 : !spirv.arm.tensor<1xf32>
   }
-}
-
-// -----
 
-spirv.module Logical GLSL450 {
   spirv.func @non_splat_vector_of_f32() -> (vector<3xf32>) "None" {
     // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
     %0 = spirv.Constant dense<[0.0, 1.0, 2.0]> : vector<3xf32>
     spirv.ReturnValue %0 : vector<3xf32>
   }
-}
 
-// -----
-
-spirv.module Logical GLSL450 {
   spirv.func @non_splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" {
     // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
     %0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 3.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
@@ -156,7 +148,7 @@ spirv.module Logical GLSL450 {
 
 // -----
 
-spirv.module Logical GLSL450 {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompositesEXT], [SPV_EXT_replicated_composites]> {
 
   spirv.SpecConstant @sc_i32_1 = 1 : i32
 
@@ -216,4 +208,4 @@ spirv.module Logical GLSL450 {
 
   // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
   spirv.SpecConstantComposite @scc_struct_of_i32_and_f32 (@sc_i32_1, @sc_i32_1, @sc_f32_1) : !spirv.struct<(i32, i32, f32)>
-}
\ No newline at end of file
+}

>From ae1eb18744012ec90467c0d2d267f12c9650fb54 Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Thu, 17 Jul 2025 22:07:57 +0100
Subject: [PATCH 8/9] Addressing further code review comments

Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
---
 ...nvertToReplicatedConstantCompositePass.cpp | 25 ++++++++-----------
 1 file changed, 11 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
index c4df57072a55d..8ca615499404b 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
@@ -25,27 +25,24 @@ namespace {
 static std::pair<Attribute, uint32_t>
 getSplatAttrAndNumElements(Attribute valueAttr) {
   Attribute attr;
-  uint32_t splatCount = 0;
+  uint32_t numElements = 0;
   if (auto splatAttr = dyn_cast<SplatElementsAttr>(valueAttr)) {
     return {splatAttr.getSplatValue<Attribute>(), splatAttr.size()};
   }
   if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
     if (llvm::all_equal(arrayAttr)) {
       attr = arrayAttr[0];
-      splatCount = arrayAttr.size();
-    }
+      numElements = arrayAttr.size();
 
-    if (attr) {
       // Find the inner-most splat value for array of composites
-      std::pair<Attribute, uint32_t> newSplatAttrAndCount =
-          getSplatAttrAndNumElements(attr);
-      if (newSplatAttrAndCount.first) {
-        return newSplatAttrAndCount;
+      auto [newAttr, newNumElements] = getSplatAttrAndNumElements(attr);
+      if (newAttr) {
+        return {newAttr, numElements * newNumElements};
       }
     }
   }
 
-  return {attr, splatCount};
+  return {attr, numElements};
 }
 
 struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
@@ -57,16 +54,16 @@ struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
     if (!compositeType)
       return rewriter.notifyMatchFailure(op, "not a composite constant");
 
-    auto [splatAttr, splatCount] = getSplatAttrAndNumElements(op.getValue());
-    if (!splatAttr)
+    auto [attr, numElements] = getSplatAttrAndNumElements(op.getValue());
+    if (!attr)
       return rewriter.notifyMatchFailure(op, "composite is not splat");
 
-    if (splatCount == 1)
+    if (numElements == 1)
       return rewriter.notifyMatchFailure(op,
                                          "composite has only one constituent");
 
     rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
-        op, op.getType(), splatAttr);
+        op, op.getType(), attr);
     return success();
   }
 };
@@ -86,7 +83,7 @@ struct SpecConstantCompositeOpConversion final
       return rewriter.notifyMatchFailure(op,
                                          "composite has only one consituent");
 
-    if (!(llvm::all_equal(constituents)))
+    if (!llvm::all_equal(constituents))
       return rewriter.notifyMatchFailure(op, "composite is not splat");
 
     auto splatConstituent = dyn_cast<FlatSymbolRefAttr>(constituents[0]);

>From 6ea73e386f002cbd77bd39a3671ab97e44d82c3a Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Fri, 18 Jul 2025 12:27:58 +0100
Subject: [PATCH 9/9] Slight change of logic for value type detection

Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
---
 ...nvertToReplicatedConstantCompositePass.cpp | 32 ++++++---
 .../replicated-const-composites.mlir          | 72 +++++++++++++++++++
 2 files changed, 96 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
index 8ca615499404b..faa0165271c60 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
@@ -22,20 +22,39 @@ namespace mlir::spirv {
 
 namespace {
 
+static Type getArrayElemType(Attribute attr) {
+  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
+    return typedAttr.getType();
+  }
+
+  if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+    return ArrayType::get(getArrayElemType(arrayAttr[0]), arrayAttr.size());
+  }
+
+  return nullptr;
+}
+
 static std::pair<Attribute, uint32_t>
-getSplatAttrAndNumElements(Attribute valueAttr) {
+getSplatAttrAndNumElements(Attribute valueAttr, Type valueType) {
   Attribute attr;
-  uint32_t numElements = 0;
+  uint32_t numElements = 1;
+
+  auto compositeType = dyn_cast_or_null<spirv::CompositeType>(valueType);
+  if (!compositeType)
+    return {nullptr, 1};
+
   if (auto splatAttr = dyn_cast<SplatElementsAttr>(valueAttr)) {
     return {splatAttr.getSplatValue<Attribute>(), splatAttr.size()};
   }
+
   if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
     if (llvm::all_equal(arrayAttr)) {
       attr = arrayAttr[0];
       numElements = arrayAttr.size();
 
       // Find the inner-most splat value for array of composites
-      auto [newAttr, newNumElements] = getSplatAttrAndNumElements(attr);
+      auto [newAttr, newNumElements] =
+          getSplatAttrAndNumElements(attr, getArrayElemType(attr));
       if (newAttr) {
         return {newAttr, numElements * newNumElements};
       }
@@ -50,11 +69,8 @@ struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
 
   LogicalResult matchAndRewrite(spirv::ConstantOp op,
                                 PatternRewriter &rewriter) const override {
-    auto compositeType = dyn_cast_or_null<spirv::CompositeType>(op.getType());
-    if (!compositeType)
-      return rewriter.notifyMatchFailure(op, "not a composite constant");
-
-    auto [attr, numElements] = getSplatAttrAndNumElements(op.getValue());
+    auto [attr, numElements] =
+        getSplatAttrAndNumElements(op.getValue(), op.getType());
     if (!attr)
       return rewriter.notifyMatchFailure(op, "composite is not splat");
 
diff --git a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
index b3a8bd830c668..56e26eee83ff9 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
@@ -49,6 +49,36 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
   }
 
+  spirv.func @array_of_splat_array_of_non_splat_vectors_of_i32() -> (!spirv.array<1 x !spirv.array<2 x vector<2xi32>>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<1 x !spirv.array<2 x vector<2xi32>>>
+    %0 = spirv.Constant [[dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>]] : !spirv.array<1 x !spirv.array<2 x vector<2xi32>>>
+    spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<2 x vector<2xi32>>>
+  }
+
+  spirv.func @array_of_one_splat_array_of_vector_of_one_i32() -> !spirv.array<1 x !spirv.array<2 x vector<1xi32>>> "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1> : vector<1xi32>] : !spirv.array<1 x !spirv.array<2 x vector<1xi32>>>
+    %cst = spirv.Constant [[dense<1> : vector<1xi32>], [dense<1> : vector<1xi32>]] : !spirv.array<1 x !spirv.array<2 x vector<1xi32>>>
+    spirv.ReturnValue %cst : !spirv.array<1 x !spirv.array<2 x vector<1xi32>>>
+  }
+
+  spirv.func @splat_array_of_array_of_one_vector_of_one_i32() -> (!spirv.array<2 x !spirv.array<1 x vector<1xi32>>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1> : vector<1xi32>] : !spirv.array<2 x !spirv.array<1 x vector<1xi32>>>
+    %0 = spirv.Constant [[dense<1> : vector<1xi32>], [dense<1> : vector<1xi32>]] : !spirv.array<2 x !spirv.array<1 x vector<1xi32>>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<1 x vector<1xi32>>>
+  }
+
+  spirv.func @array_of_one_array_of_one_splat_vector_of_i32() -> (!spirv.array<1 x !spirv.array<1 x vector<2xi32>>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
+    %0 = spirv.Constant [[dense<1> : vector<2xi32>]] : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
+    spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
+  }
+
+  spirv.func @splat_array_of_splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
+    %0 = spirv.Constant [[[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]], [[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
+  }
+
   spirv.func @splat_vector_of_f32() -> (vector<3xf32>) "None" {
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : vector<3xf32>
     %0 = spirv.Constant dense<2.0> : vector<3xf32>
@@ -97,6 +127,36 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
   }
 
+  spirv.func @array_of_splat_array_of_non_splat_vectors_of_f32() -> (!spirv.array<1 x !spirv.array<2 x vector<2xf32>>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>] : !spirv.array<1 x !spirv.array<2 x vector<2xf32>>>
+    %0 = spirv.Constant [[dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>]] : !spirv.array<1 x !spirv.array<2 x vector<2xf32>>>
+    spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<2 x vector<2xf32>>>
+  }
+
+  spirv.func @array_of_one_splat_array_of_vector_of_one_f32() -> !spirv.array<1 x !spirv.array<2 x vector<1xf32>>> "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1.000000e+00> : vector<1xf32>] : !spirv.array<1 x !spirv.array<2 x vector<1xf32>>>
+    %cst = spirv.Constant [[dense<1.0> : vector<1xf32>], [dense<1.0> : vector<1xf32>]] : !spirv.array<1 x !spirv.array<2 x vector<1xf32>>>
+    spirv.ReturnValue %cst : !spirv.array<1 x !spirv.array<2 x vector<1xf32>>>
+  }
+
+  spirv.func @splat_array_of_array_of_one_vector_of_one_f32() -> (!spirv.array<2 x !spirv.array<1 x vector<1xf32>>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1.000000e+00> : vector<1xf32>] : !spirv.array<2 x !spirv.array<1 x vector<1xf32>>>
+    %0 = spirv.Constant [[dense<1.0> : vector<1xf32>], [dense<1.0> : vector<1xf32>]] : !spirv.array<2 x !spirv.array<1 x vector<1xf32>>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<1 x vector<1xf32>>>
+  }
+
+  spirv.func @array_of_one_array_of_one_splat_vector_of_f32() -> (!spirv.array<1 x !spirv.array<1 x vector<2xf32>>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<1 x !spirv.array<1 x vector<2xf32>>>
+    %0 = spirv.Constant [[dense<1.0> : vector<2xf32>]] : !spirv.array<1 x !spirv.array<1 x vector<2xf32>>>
+    spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<2xf32>>>
+  }
+
+  spirv.func @splat_array_of_splat_array_of_non_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
+    %0 = spirv.Constant [[[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]], [[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
+  }
+
   spirv.func @array_of_one_i32() -> (!spirv.array<1 x i32>) "None" {
     // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
     %0 = spirv.Constant [1 : i32] : !spirv.array<1 x i32>
@@ -144,6 +204,18 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     %0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 3.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
     spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
   }
+
+  spirv.func @array_of_one_array_of_one_non_splat_vector_of_i32() -> (!spirv.array<1 x !spirv.array<1 x vector<2xi32>>>) "None" {
+    // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
+    %0 = spirv.Constant [[dense<[1, 2]> : vector<2xi32>]] : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
+    spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
+  }
+  
+  spirv.func @array_of_one_array_of_one_vector_of_one_i32() -> (!spirv.array<1 x !spirv.array<1 x vector<1xi32>>>) "None" {
+    // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
+    %0 = spirv.Constant [[dense<1> : vector<1xi32>]] : !spirv.array<1 x !spirv.array<1 x vector<1xi32>>>
+    spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<1xi32>>>
+  }
 }
 
 // -----



More information about the Mlir-commits mailing list