[Mlir-commits] [mlir] 10518c7 - [mlir][spirv] Add conversion pass to rewrite splat constant composite… (#148910)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 18 09:59:43 PDT 2025


Author: Mohammadreza Ameri Mahabadian
Date: 2025-07-18T12:59:39-04:00
New Revision: 10518c76de091bf23e72a8761c1eff561ce6e074

URL: https://github.com/llvm/llvm-project/commit/10518c76de091bf23e72a8761c1eff561ce6e074
DIFF: https://github.com/llvm/llvm-project/commit/10518c76de091bf23e72a8761c1eff561ce6e074.diff

LOG: [mlir][spirv] Add conversion pass to rewrite splat constant composites to replicated form (#148910)

This adds a new SPIR-V dialect-level conversion pass
`ConvertToReplicatedConstantCompositePass` (registered as
`--spirv-promote-to-replicated-constants`). This pass looks for splat
composite `spirv.Constant` or `spirv.SpecConstantComposite` ops and rewrites
them into `spirv.EXT.ConstantCompositeReplicate` or
`spirv.EXT.SpecConstantCompositeReplicate`, respectively.

---------

Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>

Added: 
    mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
    mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir

Modified: 
    mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
    mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
index 2d9befe78001d..2016bea43fc8a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
@@ -77,4 +77,11 @@ def SPIRVWebGPUPreparePass : Pass<"spirv-webgpu-prepare", "spirv::ModuleOp"> {
                 "and replacing with supported ones";
 }
 
+// Promotes splat composite constants and spec constants to the replicated
+// composite ops from SPV_EXT_replicated_composites. Implemented by
+// ConvertToReplicatedConstantCompositePass.
+def SPIRVReplicatedConstantCompositePass
+    : Pass<"spirv-promote-to-replicated-constants", "spirv::ModuleOp"> {
+  let summary =
+      "Convert splat composite constants and spec constants to corresponding "
+      "replicated constant composite ops defined by "
+      "SPV_EXT_replicated_composites";
+}
+
 #endif // MLIR_DIALECT_SPIRV_TRANSFORMS_PASSES

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
index 68e0206e30a59..b947447dad46a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 set(LLVM_OPTIONAL_SOURCES
   CanonicalizeGLPass.cpp
+  ConvertToReplicatedConstantCompositePass.cpp
   DecorateCompositeTypeLayoutPass.cpp
   LowerABIAttributesPass.cpp
   RewriteInsertsPass.cpp
@@ -30,6 +31,7 @@ add_mlir_dialect_library(MLIRSPIRVConversion
 
 add_mlir_dialect_library(MLIRSPIRVTransforms
   CanonicalizeGLPass.cpp
+  ConvertToReplicatedConstantCompositePass.cpp
   DecorateCompositeTypeLayoutPass.cpp
   LowerABIAttributesPass.cpp
   RewriteInsertsPass.cpp

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
new file mode 100644
index 0000000000000..dbbe23aa08b3c
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
@@ -0,0 +1,129 @@
+//===- ConvertToReplicatedConstantCompositePass.cpp -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert a splat composite spirv.Constant and
+// spirv.SpecConstantComposite to spirv.EXT.ConstantCompositeReplicate and
+// spirv.EXT.SpecConstantCompositeReplicate respectively.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+namespace mlir::spirv {
+#define GEN_PASS_DEF_SPIRVREPLICATEDCONSTANTCOMPOSITEPASS
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
+
+namespace {
+
+/// Reconstructs the type of a `spirv.Constant` constituent from its attribute.
+///
+/// A typed attribute (e.g. an integer, float or dense elements attribute)
+/// yields its own type. An array attribute yields a `spirv.array` whose
+/// element type is derived from its first element and whose length is the
+/// attribute's size. Returns a null type when no type can be derived: for an
+/// empty array attribute, or when an element is neither typed nor an array.
+static Type getArrayElemType(Attribute attr) {
+  if (auto typedAttr = dyn_cast<TypedAttr>(attr))
+    return typedAttr.getType();
+
+  auto arrayAttr = dyn_cast<ArrayAttr>(attr);
+  // An empty array attribute has no first element to derive a type from, and
+  // indexing it would be out of bounds.
+  if (!arrayAttr || arrayAttr.empty())
+    return nullptr;
+
+  // Propagate failure instead of building an array of a null element type.
+  Type elemType = getArrayElemType(arrayAttr[0]);
+  if (!elemType)
+    return nullptr;
+  return ArrayType::get(elemType, arrayAttr.size());
+}
+
+/// Determines whether a composite constant value is a splat.
+///
+/// On success returns the attribute to replicate together with the total
+/// number of times it is replicated. For arrays whose elements are all equal,
+/// the search recurses into the element to find the inner-most splat value,
+/// multiplying the element counts of every level it descends through.
+/// Returns `{nullptr, 1}` when `valueType` is not a SPIR-V composite type or
+/// when `valueAttr` is not a splat.
+static std::pair<Attribute, uint32_t>
+getSplatAttrAndNumElements(Attribute valueAttr, Type valueType) {
+  auto compositeType = dyn_cast_or_null<spirv::CompositeType>(valueType);
+  if (!compositeType)
+    return {nullptr, 1};
+
+  if (auto splatAttr = dyn_cast<SplatElementsAttr>(valueAttr))
+    return {splatAttr.getSplatValue<Attribute>(), splatAttr.size()};
+
+  auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr);
+  // `llvm::all_equal` is true for an empty range, so guard the empty case
+  // before reading `arrayAttr[0]` below.
+  if (!arrayAttr || arrayAttr.empty() || !llvm::all_equal(arrayAttr))
+    return {nullptr, 1};
+
+  Attribute attr = arrayAttr[0];
+  uint32_t numElements = arrayAttr.size();
+
+  // Find the inner-most splat value for arrays of composites.
+  auto [newAttr, newNumElements] =
+      getSplatAttrAndNumElements(attr, getArrayElemType(attr));
+  if (newAttr) {
+    attr = newAttr;
+    numElements *= newNumElements;
+  }
+  return {attr, numElements};
+}
+
+/// Rewrites a splat composite `spirv.Constant` into
+/// `spirv.EXT.ConstantCompositeReplicate`. The replicated value is the one
+/// reported by `getSplatAttrAndNumElements`; the result type is unchanged.
+struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::ConstantOp op,
+                                PatternRewriter &rewriter) const override {
+    Type resultType = op.getType();
+    std::pair<Attribute, uint32_t> splatInfo =
+        getSplatAttrAndNumElements(op.getValue(), resultType);
+    Attribute replicated = splatInfo.first;
+    uint32_t totalCount = splatInfo.second;
+
+    // Non-splat and non-composite constants are left untouched.
+    if (!replicated)
+      return rewriter.notifyMatchFailure(op, "composite is not splat");
+
+    // A composite with a single constituent in total is not rewritten.
+    if (totalCount == 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "composite has only one constituent");
+
+    rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
+        op, resultType, replicated);
+    return success();
+  }
+};
+
+/// Rewrites a `spirv.SpecConstantComposite` whose constituents all reference
+/// the same flat symbol into `spirv.EXT.SpecConstantCompositeReplicate`. The
+/// original symbol name and result type are kept.
+struct SpecConstantCompositeOpConversion final
+    : OpRewritePattern<spirv::SpecConstantCompositeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::SpecConstantCompositeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto compositeType = dyn_cast_or_null<spirv::CompositeType>(op.getType());
+    if (!compositeType)
+      return rewriter.notifyMatchFailure(op, "not a composite constant");
+
+    ArrayAttr constituents = op.getConstituents();
+    // `llvm::all_equal` is true for an empty range, so without this guard an
+    // empty composite would reach the out-of-bounds `constituents[0]` below.
+    if (constituents.empty())
+      return rewriter.notifyMatchFailure(op, "composite has no constituents");
+
+    if (constituents.size() == 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "composite has only one constituent");
+
+    if (!llvm::all_equal(constituents))
+      return rewriter.notifyMatchFailure(op, "composite is not splat");
+
+    auto splatConstituent = dyn_cast<FlatSymbolRefAttr>(constituents[0]);
+    if (!splatConstituent)
+      return rewriter.notifyMatchFailure(
+          op, "expected flat symbol reference for splat constituent");
+
+    rewriter.replaceOpWithNewOp<spirv::EXTSpecConstantCompositeReplicateOp>(
+        op, TypeAttr::get(op.getType()), op.getSymNameAttr(), splatConstituent);
+
+    return success();
+  }
+};
+
+/// Pass driver, anchored on `spirv::ModuleOp`. It registers the constant and
+/// spec-constant rewrite patterns and applies them with
+/// `walkAndApplyPatterns`.
+struct ConvertToReplicatedConstantCompositePass final
+    : spirv::impl::SPIRVReplicatedConstantCompositePassBase<
+          ConvertToReplicatedConstantCompositePass> {
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet rewrites(ctx);
+    rewrites.add<ConstantOpConversion>(ctx);
+    rewrites.add<SpecConstantCompositeOpConversion>(ctx);
+    walkAndApplyPatterns(getOperation(), std::move(rewrites));
+  }
+};
+
+} // namespace
+} // namespace mlir::spirv

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
new file mode 100644
index 0000000000000..56e26eee83ff9
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
@@ -0,0 +1,283 @@
+// RUN: mlir-opt --spirv-promote-to-replicated-constants --split-input-file %s | FileCheck %s
+
+// Splat composite spirv.Constant ops are rewritten into
+// spirv.EXT.ConstantCompositeReplicate; composites with a single constituent
+// in total, or with differing constituents, are left unchanged.
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompositesEXT], [SPV_EXT_replicated_composites]> {
+  spirv.func @splat_vector_of_i32() -> (vector<3xi32>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : vector<3xi32>
+    %0 = spirv.Constant dense<2> : vector<3xi32>
+    spirv.ReturnValue %0 : vector<3xi32>
+  }
+
+  spirv.func @splat_array_of_i32() -> (!spirv.array<3 x i32>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32>
+    %0 = spirv.Constant [1 : i32, 1 : i32, 1 : i32] : !spirv.array<3 x i32>
+    spirv.ReturnValue %0 : !spirv.array<3 x i32>
+  }
+
+  spirv.func @splat_array_of_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
+    %0 = spirv.Constant [[3 : i32, 3 : i32, 3 : i32], [3 : i32, 3 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
+  }
+
+  spirv.func @splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
+    %0 = spirv.Constant [[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
+  }
+
+  spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
+    %0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
+  }
+
+  spirv.func @splat_array_of_splat_vectors_of_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>>
+    %0 = spirv.Constant [dense<2> : vector<2xi32>, dense<2> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
+  }
+
+  spirv.func @splat_tensor_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
+    %0 = spirv.Constant dense<3> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
+  }
+
+  spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.arm.tensor<2x3xi32>
+    %0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+  }
+
+  spirv.func @array_of_splat_array_of_non_splat_vectors_of_i32() -> (!spirv.array<1 x !spirv.array<2 x vector<2xi32>>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<1 x !spirv.array<2 x vector<2xi32>>>
+    %0 = spirv.Constant [[dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>]] : !spirv.array<1 x !spirv.array<2 x vector<2xi32>>>
+    spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<2 x vector<2xi32>>>
+  }
+
+  // The value's shape must match the declared 1 x 2 array type.
+  spirv.func @array_of_one_splat_array_of_vector_of_one_i32() -> !spirv.array<1 x !spirv.array<2 x vector<1xi32>>> "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1> : vector<1xi32>] : !spirv.array<1 x !spirv.array<2 x vector<1xi32>>>
+    %cst = spirv.Constant [[dense<1> : vector<1xi32>, dense<1> : vector<1xi32>]] : !spirv.array<1 x !spirv.array<2 x vector<1xi32>>>
+    spirv.ReturnValue %cst : !spirv.array<1 x !spirv.array<2 x vector<1xi32>>>
+  }
+
+  spirv.func @splat_array_of_array_of_one_vector_of_one_i32() -> (!spirv.array<2 x !spirv.array<1 x vector<1xi32>>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1> : vector<1xi32>] : !spirv.array<2 x !spirv.array<1 x vector<1xi32>>>
+    %0 = spirv.Constant [[dense<1> : vector<1xi32>], [dense<1> : vector<1xi32>]] : !spirv.array<2 x !spirv.array<1 x vector<1xi32>>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<1 x vector<1xi32>>>
+  }
+
+  spirv.func @array_of_one_array_of_one_splat_vector_of_i32() -> (!spirv.array<1 x !spirv.array<1 x vector<2xi32>>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
+    %0 = spirv.Constant [[dense<1> : vector<2xi32>]] : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
+    spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
+  }
+
+  spirv.func @splat_array_of_splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
+    %0 = spirv.Constant [[[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]], [[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
+  }
+
+  spirv.func @splat_vector_of_f32() -> (vector<3xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : vector<3xf32>
+    %0 = spirv.Constant dense<2.0> : vector<3xf32>
+    spirv.ReturnValue %0 : vector<3xf32>
+  }
+
+  spirv.func @splat_array_of_f32() -> (!spirv.array<3 x f32>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<3 x f32>
+    %0 = spirv.Constant [1.0 : f32, 1.0 : f32, 1.0 : f32] : !spirv.array<3 x f32>
+    spirv.ReturnValue %0 : !spirv.array<3 x f32>
+  }
+
+  spirv.func @splat_array_of_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
+    %0 = spirv.Constant [[3.0 : f32, 3.0 : f32, 3.0 : f32], [3.0 : f32, 3.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
+  }
+
+  spirv.func @splat_array_of_non_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
+    %0 = spirv.Constant [[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
+  }
+
+  spirv.func @splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
+    %0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
+  }
+
+  spirv.func @splat_array_of_splat_vectors_of_f32() -> (!spirv.array<2 x vector<2xf32>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.array<2 x vector<2xf32>>
+    %0 = spirv.Constant [dense<2.0> : vector<2xf32>, dense<2.0> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
+  }
+
+  spirv.func @splat_tensor_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
+    %0 = spirv.Constant dense<3.0> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
+  }
+
+  spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.arm.tensor<2x3xf32>
+    %0 = spirv.Constant dense<2.0> : !spirv.arm.tensor<2x3xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+  }
+
+  spirv.func @array_of_splat_array_of_non_splat_vectors_of_f32() -> (!spirv.array<1 x !spirv.array<2 x vector<2xf32>>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>] : !spirv.array<1 x !spirv.array<2 x vector<2xf32>>>
+    %0 = spirv.Constant [[dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>]] : !spirv.array<1 x !spirv.array<2 x vector<2xf32>>>
+    spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<2 x vector<2xf32>>>
+  }
+
+  // The value's shape must match the declared 1 x 2 array type.
+  spirv.func @array_of_one_splat_array_of_vector_of_one_f32() -> !spirv.array<1 x !spirv.array<2 x vector<1xf32>>> "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1.000000e+00> : vector<1xf32>] : !spirv.array<1 x !spirv.array<2 x vector<1xf32>>>
+    %cst = spirv.Constant [[dense<1.0> : vector<1xf32>, dense<1.0> : vector<1xf32>]] : !spirv.array<1 x !spirv.array<2 x vector<1xf32>>>
+    spirv.ReturnValue %cst : !spirv.array<1 x !spirv.array<2 x vector<1xf32>>>
+  }
+
+  spirv.func @splat_array_of_array_of_one_vector_of_one_f32() -> (!spirv.array<2 x !spirv.array<1 x vector<1xf32>>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1.000000e+00> : vector<1xf32>] : !spirv.array<2 x !spirv.array<1 x vector<1xf32>>>
+    %0 = spirv.Constant [[dense<1.0> : vector<1xf32>], [dense<1.0> : vector<1xf32>]] : !spirv.array<2 x !spirv.array<1 x vector<1xf32>>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<1 x vector<1xf32>>>
+  }
+
+  spirv.func @array_of_one_array_of_one_splat_vector_of_f32() -> (!spirv.array<1 x !spirv.array<1 x vector<2xf32>>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<1 x !spirv.array<1 x vector<2xf32>>>
+    %0 = spirv.Constant [[dense<1.0> : vector<2xf32>]] : !spirv.array<1 x !spirv.array<1 x vector<2xf32>>>
+    spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<2xf32>>>
+  }
+
+  spirv.func @splat_array_of_splat_array_of_non_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>) "None" {
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
+    %0 = spirv.Constant [[[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]], [[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
+  }
+
+  spirv.func @array_of_one_i32() -> (!spirv.array<1 x i32>) "None" {
+    // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
+    %0 = spirv.Constant [1 : i32] : !spirv.array<1 x i32>
+    spirv.ReturnValue %0 : !spirv.array<1 x i32>
+  }
+
+  spirv.func @arm_tensor_of_one_i32() -> (!spirv.arm.tensor<1xi32>) "None" {
+    // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
+    %0 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<1xi32>
+  }
+
+  spirv.func @non_splat_vector_of_i32() -> (vector<3xi32>) "None" {
+    // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
+    %0 = spirv.Constant dense<[0, 1, 2]> : vector<3xi32>
+    spirv.ReturnValue %0 : vector<3xi32>
+  }
+
+  spirv.func @non_splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" {
+    // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
+    %0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 3]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
+  }
+
+  spirv.func @array_of_one_f32() -> (!spirv.array<1 x f32>) "None" {
+    // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
+    %0 = spirv.Constant [1.0 : f32] : !spirv.array<1 x f32>
+    spirv.ReturnValue %0 : !spirv.array<1 x f32>
+  }
+
+  spirv.func @arm_tensor_of_one_f32() -> (!spirv.arm.tensor<1xf32>) "None" {
+    // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
+    %0 = spirv.Constant dense<1.0> : !spirv.arm.tensor<1xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<1xf32>
+  }
+
+  spirv.func @non_splat_vector_of_f32() -> (vector<3xf32>) "None" {
+    // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
+    %0 = spirv.Constant dense<[0.0, 1.0, 2.0]> : vector<3xf32>
+    spirv.ReturnValue %0 : vector<3xf32>
+  }
+
+  spirv.func @non_splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" {
+    // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
+    %0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 3.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
+  }
+
+  spirv.func @array_of_one_array_of_one_non_splat_vector_of_i32() -> (!spirv.array<1 x !spirv.array<1 x vector<2xi32>>>) "None" {
+    // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
+    %0 = spirv.Constant [[dense<[1, 2]> : vector<2xi32>]] : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
+    spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
+  }
+
+  spirv.func @array_of_one_array_of_one_vector_of_one_i32() -> (!spirv.array<1 x !spirv.array<1 x vector<1xi32>>>) "None" {
+    // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
+    %0 = spirv.Constant [[dense<1> : vector<1xi32>]] : !spirv.array<1 x !spirv.array<1 x vector<1xi32>>>
+    spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<1xi32>>>
+  }
+}
+
+// -----
+
+// Splat spirv.SpecConstantComposite ops (every constituent references the
+// same spec constant) are rewritten into spirv.EXT.SpecConstantCompositeReplicate,
+// keeping the symbol name and result type.
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompositesEXT], [SPV_EXT_replicated_composites]> {
+
+  spirv.SpecConstant @sc_i32_1 = 1 : i32
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_i32 (@sc_i32_1) : !spirv.array<3 x i32>
+  spirv.SpecConstantComposite @scc_splat_array_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.array<3 x i32>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_i32 (@sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+  spirv.SpecConstantComposite @scc_splat_struct_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_i32 (@sc_i32_1) : vector<3xi32>
+  spirv.SpecConstantComposite @scc_splat_vector_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : vector<3 x i32>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_arm_tensor_of_i32 (@sc_i32_1) : !spirv.arm.tensor<3xi32>
+  spirv.SpecConstantComposite @scc_splat_arm_tensor_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.arm.tensor<3xi32>
+
+  spirv.SpecConstant @sc_f32_1 = 1.0 : f32
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_f32 (@sc_f32_1) : !spirv.array<3 x f32>
+  spirv.SpecConstantComposite @scc_splat_array_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.array<3 x f32>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_f32 (@sc_f32_1) : !spirv.struct<(f32, f32, f32)>
+  spirv.SpecConstantComposite @scc_splat_struct_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.struct<(f32, f32, f32)>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_f32 (@sc_f32_1) : vector<3xf32>
+  spirv.SpecConstantComposite @scc_splat_vector_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : vector<3 x f32>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_arm_tensor_of_f32 (@sc_f32_1) : !spirv.arm.tensor<3xf32>
+  spirv.SpecConstantComposite @scc_splat_arm_tensor_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.arm.tensor<3xf32>
+
+  // Negative cases below: composites with a single constituent, or whose
+  // constituents reference different symbols, must not be rewritten.
+  spirv.SpecConstant @sc_i32_2 = 2 : i32
+
+  // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
+  spirv.SpecConstantComposite @scc_array_of_one_i32 (@sc_i32_1) : !spirv.array<1 x i32>
+
+  // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
+  spirv.SpecConstantComposite @scc_arm_tensor_of_one_i32 (@sc_i32_1) : !spirv.arm.tensor<1xi32>
+
+  // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
+  spirv.SpecConstantComposite @scc_non_splat_vector_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_2) : vector<3 x i32>
+
+  // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
+  spirv.SpecConstantComposite @scc_non_splat_arm_tensor_of_i32 (@sc_i32_2, @sc_i32_1, @sc_i32_1) : !spirv.arm.tensor<3xi32>
+
+  spirv.SpecConstant @sc_f32_2 = 2.0 : f32
+
+  // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
+  spirv.SpecConstantComposite @scc_array_of_one_f32 (@sc_f32_1) : !spirv.array<1 x f32>
+
+  // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
+  spirv.SpecConstantComposite @scc_arm_tensor_of_one_f32 (@sc_f32_1) : !spirv.arm.tensor<1xf32>
+
+  // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
+  spirv.SpecConstantComposite @scc_non_splat_vector_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_2) : vector<3 x f32>
+
+  // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
+  spirv.SpecConstantComposite @scc_non_splat_arm_tensor_of_f32 (@sc_f32_2, @sc_f32_1, @sc_f32_1) : !spirv.arm.tensor<3xf32>
+
+  // Constituents reference different symbols (@sc_i32_1 vs @sc_f32_1), so
+  // this struct is not a splat.
+  // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
+  spirv.SpecConstantComposite @scc_struct_of_i32_and_f32 (@sc_i32_1, @sc_i32_1, @sc_f32_1) : !spirv.struct<(i32, i32, f32)>
+}


        


More information about the Mlir-commits mailing list