[Mlir-commits] [mlir] [mlir][spirv] Add conversion pass to rewrite splat constant composite… (PR #148910)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jul 15 10:49:17 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-spirv
Author: Mohammadreza Ameri Mahabadian (mahabadm)
<details>
<summary>Changes</summary>
…s to replicated form
This adds a new SPIR-V dialect-level conversion pass `ConversionToReplicatedConstantCompositePass`. This pass looks for splat composite `spirv.Constant` or `spirv.SpecConstantComposite` and rewrites them into `spirv.EXT.ConstantCompositeReplicate` or `spirv.EXT.SpecConstantCompositeReplicate`, respectively.
---
Full diff: https://github.com/llvm/llvm-project/pull/148910.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td (+7)
- (modified) mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt (+2)
- (added) mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp (+135)
- (added) mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir (+192)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
index 2d9befe78001d..3c04db8396367 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
@@ -77,4 +77,11 @@ def SPIRVWebGPUPreparePass : Pass<"spirv-webgpu-prepare", "spirv::ModuleOp"> {
"and replacing with supported ones";
}
+def SPIRVReplicatedConstantCompositePass
+    : Pass<"spirv-replicated-const-composite", "spirv::ModuleOp"> {
+  let summary = "Convert splat composite constants and spec constants to "
+                "corresponding replicated constant composite ops defined by "
+                "SPV_EXT_replicated_composites";
+}
+
#endif // MLIR_DIALECT_SPIRV_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
index 68e0206e30a59..c675af9d048cc 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
set(LLVM_OPTIONAL_SOURCES
CanonicalizeGLPass.cpp
+ ConversionToReplicatedConstantCompositePass.cpp
DecorateCompositeTypeLayoutPass.cpp
LowerABIAttributesPass.cpp
RewriteInsertsPass.cpp
@@ -30,6 +31,7 @@ add_mlir_dialect_library(MLIRSPIRVConversion
add_mlir_dialect_library(MLIRSPIRVTransforms
CanonicalizeGLPass.cpp
+ ConversionToReplicatedConstantCompositePass.cpp
DecorateCompositeTypeLayoutPass.cpp
LowerABIAttributesPass.cpp
RewriteInsertsPass.cpp
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
new file mode 100644
index 0000000000000..530a0f4aa67f5
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
@@ -0,0 +1,135 @@
+//===- ConversionToReplicatedConstantCompositePass.cpp --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert a splat composite spirv.Constant and
+// spirv.SpecConstantComposite to spirv.EXT.ConstantCompositeReplicate and
+// spirv.EXT.SpecConstantCompositeReplicate respectively.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace spirv {
+#define GEN_PASS_DEF_SPIRVREPLICATEDCONSTANTCOMPOSITEPASS
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
+} // namespace spirv
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Returns the innermost replicated value of `valueAttr` when it is a splat,
+/// or a null attribute when it is not.
+///
+/// A `DenseElementsAttr` is a splat when `isSplat()` holds; an `ArrayAttr` is
+/// a splat when it is non-empty and all of its elements compare equal. If the
+/// replicated element is itself a splat (a typed attribute of SPIR-V composite
+/// type, or a nested `ArrayAttr`), the search recurses so that the innermost
+/// replicated value is returned.
+///
+/// On success, `splatCount` receives the number of innermost values the
+/// returned attribute stands for, i.e. the product of the element counts of
+/// every collapsed level. It is left untouched when a null attribute is
+/// returned. The count is taken by reference: a by-value parameter would
+/// silently drop it and leave the caller's variable at its initial value.
+Attribute getSplatAttribute(Attribute valueAttr, uint32_t &splatCount) {
+  Attribute attr;
+  uint32_t count = 0;
+  if (auto denseAttr = dyn_cast<DenseElementsAttr>(valueAttr)) {
+    if (denseAttr.isSplat()) {
+      attr = denseAttr.getSplatValue<Attribute>();
+      count = denseAttr.size();
+    }
+  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
+    // Guard against an empty array: `adjacent_find` reports "no mismatch" for
+    // it, and indexing element 0 would then be out of bounds.
+    if (!arrayAttr.empty() &&
+        std::adjacent_find(arrayAttr.begin(), arrayAttr.end(),
+                           std::not_equal_to<>()) == arrayAttr.end()) {
+      attr = arrayAttr[0];
+      count = arrayAttr.size();
+    }
+  }
+
+  if (!attr)
+    return attr;
+
+  // Only a composite-typed attribute or a nested array can itself be a splat.
+  bool mayBeNestedSplat = false;
+  if (auto typedAttr = dyn_cast<TypedAttr>(attr))
+    mayBeNestedSplat = isa<spirv::CompositeType>(typedAttr.getType());
+  else
+    mayBeNestedSplat = isa<ArrayAttr>(attr);
+
+  if (mayBeNestedSplat) {
+    uint32_t innerCount = 0;
+    if (Attribute innerAttr = getSplatAttribute(attr, innerCount)) {
+      attr = innerAttr;
+      count *= innerCount;
+    }
+  }
+
+  splatCount = count;
+  return attr;
+}
+
+} // namespace
+
+namespace {
+/// Pass implementation built on the TableGen-generated
+/// `SPIRVReplicatedConstantCompositePassBase`; all of its work happens in
+/// `runOnOperation`.
+class ConversionToReplicatedConstantCompositePass
+    : public spirv::impl::SPIRVReplicatedConstantCompositePassBase<
+          ConversionToReplicatedConstantCompositePass> {
+public:
+  /// Pass entry point, invoked on the operation the pass is scheduled on.
+  void runOnOperation() override;
+};
+
+/// Rewrites a `spirv::ConstantOp` of composite type whose value is a splat
+/// into a `spirv::EXTConstantCompositeReplicateOp` carrying only the
+/// replicated value.
+///
+/// The match fails when the result type is not a SPIR-V composite type, when
+/// `getSplatAttribute` finds no splat value, or when the reported splat count
+/// is exactly one.
+class ConstantOpConversion : public OpRewritePattern<spirv::ConstantOp> {
+  using OpRewritePattern<spirv::ConstantOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::ConstantOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!dyn_cast_or_null<spirv::CompositeType>(op.getType()))
+      return rewriter.notifyMatchFailure(op, "not a composite constant");
+
+    uint32_t splatCount = 0;
+    Attribute splatAttr = getSplatAttribute(op.getValue(), splatCount);
+    if (!splatAttr)
+      return rewriter.notifyMatchFailure(op, "composite is not splat");
+
+    // Replicating a single constituent would not shrink the constant.
+    if (splatCount == 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "composite has only one constituent");
+
+    // The result type is kept; only the value attribute is collapsed.
+    rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
+        op, op.getType(), splatAttr);
+
+    return success();
+  }
+};
+
+/// Rewrites a `spirv::SpecConstantCompositeOp` whose constituents are all the
+/// same symbol into a `spirv::EXTSpecConstantCompositeReplicateOp` that
+/// references that symbol once.
+///
+/// The match fails when the type is not a SPIR-V composite type, when there
+/// are fewer than two constituents, when the constituents are not all equal,
+/// or when the first constituent is not a `FlatSymbolRefAttr`.
+class SpecConstantCompositeOpConversion
+    : public OpRewritePattern<spirv::SpecConstantCompositeOp> {
+  using OpRewritePattern<spirv::SpecConstantCompositeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::SpecConstantCompositeOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!dyn_cast_or_null<spirv::CompositeType>(op.getType()))
+      return rewriter.notifyMatchFailure(op, "not a composite constant");
+
+    auto constituents = op.getConstituents();
+    // An empty list would pass the equality scan below and then be indexed
+    // out of bounds, so reject it explicitly.
+    if (constituents.empty())
+      return rewriter.notifyMatchFailure(op, "composite has no constituents");
+
+    if (constituents.size() == 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "composite has only one constituent");
+
+    // Splat means no pair of neighbouring constituents differs.
+    if (std::adjacent_find(constituents.begin(), constituents.end(),
+                           std::not_equal_to<>()) != constituents.end())
+      return rewriter.notifyMatchFailure(op, "composite is not splat");
+
+    auto splatConstituent = dyn_cast<FlatSymbolRefAttr>(constituents[0]);
+    // Do not hand a null attribute to the builder if the cast fails.
+    if (!splatConstituent)
+      return rewriter.notifyMatchFailure(
+          op, "constituent is not a flat symbol reference");
+
+    rewriter.replaceOpWithNewOp<spirv::EXTSpecConstantCompositeReplicateOp>(
+        op, TypeAttr::get(op.getType()), op.getSymNameAttr(), splatConstituent);
+
+    return success();
+  }
+};
+
+/// Collects the two rewrite patterns defined in this file, applies them
+/// greedily to the operation under the pass, and signals pass failure if the
+/// greedy driver reports failure.
+void ConversionToReplicatedConstantCompositePass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+
+  RewritePatternSet rewrites(ctx);
+  rewrites.add<ConstantOpConversion, SpecConstantCompositeOpConversion>(ctx);
+
+  const LogicalResult status =
+      applyPatternsGreedily(getOperation(), std::move(rewrites));
+  if (failed(status))
+    signalPassFailure();
+}
+
+} // namespace
\ No newline at end of file
diff --git a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
new file mode 100644
index 0000000000000..f8cd4bb256bfe
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
@@ -0,0 +1,192 @@
+// RUN: mlir-opt -spirv-replicated-const-composite -split-input-file -verify-diagnostics %s -o - | FileCheck %s
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_vector_of_i32() -> (vector<3xi32>) "None" {
+ %0 = spirv.Constant dense<2> : vector<3xi32>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : vector<3xi32>
+ spirv.ReturnValue %0 : vector<3xi32>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_i32() -> (!spirv.array<3 x i32>) "None" {
+ %0 = spirv.Constant [1 : i32, 1 : i32, 1 : i32] : !spirv.array<3 x i32>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32>
+ spirv.ReturnValue %0 : !spirv.array<3 x i32>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
+ %0 = spirv.Constant [[3 : i32, 3 : i32, 3 : i32], [3 : i32, 3 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
+ %0 = spirv.Constant [[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" {
+ %0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_splat_vectors_of_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" {
+ %0 = spirv.Constant [dense<2> : vector<2xi32>, dense<2> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_tensor_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
+ %0 = spirv.Constant dense<3> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.arm.tensor<2x3xi32>
+ %0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_vector_of_f32() -> (vector<3xf32>) "None" {
+ %0 = spirv.Constant dense<2.0> : vector<3xf32>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : vector<3xf32>
+ spirv.ReturnValue %0 : vector<3xf32>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_f32() -> (!spirv.array<3 x f32>) "None" {
+ %0 = spirv.Constant [1.0 : f32, 1.0 : f32, 1.0 : f32] : !spirv.array<3 x f32>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<3 x f32>
+ spirv.ReturnValue %0 : !spirv.array<3 x f32>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
+ %0 = spirv.Constant [[3.0 : f32, 3.0 : f32, 3.0 : f32], [3.0 : f32, 3.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_non_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
+ %0 = spirv.Constant [[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" {
+ %0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_splat_vectors_of_f32() -> (!spirv.array<2 x vector<2xf32>>) "None" {
+ %0 = spirv.Constant [dense<2.0> : vector<2xf32>, dense<2.0> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.array<2 x vector<2xf32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_tensor_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
+ %0 = spirv.Constant dense<3.0> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.arm.tensor<2x3xf32>
+ %0 = spirv.Constant dense<2.0> : !spirv.arm.tensor<2x3xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompositesEXT], [SPV_EXT_replicated_composites]> {
+
+ spirv.SpecConstant @sc_i32_1 = 1 : i32
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_i32 (@sc_i32_1) : !spirv.array<3 x i32>
+ spirv.SpecConstantComposite @scc_splat_array_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.array<3 x i32>
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_i32 (@sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+ spirv.SpecConstantComposite @scc_splat_struct_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_i32 (@sc_i32_1) : vector<3xi32>
+ spirv.SpecConstantComposite @scc_splat_vector_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : vector<3 x i32>
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_arm_tensor_of_i32 (@sc_i32_1) : !spirv.arm.tensor<3xi32>
+ spirv.SpecConstantComposite @scc_splat_arm_tensor_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.arm.tensor<3xi32>
+
+ spirv.SpecConstant @sc_f32_1 = 1.0 : f32
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_f32 (@sc_f32_1) : !spirv.array<3 x f32>
+ spirv.SpecConstantComposite @scc_splat_array_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.array<3 x f32>
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_f32 (@sc_f32_1) : !spirv.struct<(f32, f32, f32)>
+ spirv.SpecConstantComposite @scc_splat_struct_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.struct<(f32, f32, f32)>
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_f32 (@sc_f32_1) : vector<3xf32>
+ spirv.SpecConstantComposite @scc_splat_vector_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : vector<3 x f32>
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_arm_tensor_of_f32 (@sc_f32_1) : !spirv.arm.tensor<3xf32>
+ spirv.SpecConstantComposite @scc_splat_arm_tensor_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.arm.tensor<3xf32>
+}
\ No newline at end of file
``````````
</details>
https://github.com/llvm/llvm-project/pull/148910
More information about the Mlir-commits
mailing list