[Mlir-commits] [mlir] [mlir][spirv] Implement UMod canonicalization for vector constants (PR #141902)
Darren Wihandi
llvmlistbot at llvm.org
Wed May 28 23:40:51 PDT 2025
https://github.com/fairywreath created https://github.com/llvm/llvm-project/pull/141902
(No PR description provided.)
>From ff5cd524e27d79ba5750c3bc9451751fb93a6984 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Thu, 29 May 2025 02:13:25 -0400
Subject: [PATCH] [mlir][spirv] Implement UMod canonicalization for vector
constants
---
.../SPIRV/IR/SPIRVCanonicalization.cpp | 27 ++++++++++++-------
.../SPIRV/Transforms/canonicalize.mlir | 27 ++++++++++++++-----
2 files changed, 39 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index e36d4b910193e..89b46577f061c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -326,7 +326,6 @@ void spirv::UMulExtendedOp::getCanonicalizationPatterns(
// The transformation is only applied if one divisor is a multiple of the other.
-// TODO(https://github.com/llvm/llvm-project/issues/63174): Add support for vector constants
struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
using OpRewritePattern::OpRewritePattern;
@@ -336,19 +335,29 @@ struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
if (!prevUMod)
return failure();
- IntegerAttr prevValue;
- IntegerAttr currValue;
+ TypedAttr prevValue;
+ TypedAttr currValue;
if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) ||
!matchPattern(umodOp.getOperand(1), m_Constant(&currValue)))
return failure();
- APInt prevConstValue = prevValue.getValue();
- APInt currConstValue = currValue.getValue();
+ // Ensure that previous divisor is a multiple of the current divisor. If
+ // not, fail the transformation.
+ bool isApplicable = false;
+ if (auto prevInt = dyn_cast<IntegerAttr>(prevValue)) {
+ auto currInt = dyn_cast<IntegerAttr>(currValue);
+ isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0;
+ } else if (auto prevVec = dyn_cast<DenseElementsAttr>(prevValue)) {
+ auto currVec = dyn_cast<DenseElementsAttr>(currValue);
+ isApplicable = llvm::all_of(
+ llvm::zip(prevVec.getValues<APInt>(), currVec.getValues<APInt>()),
+ [](auto pair) {
+ const auto &[a, b] = pair;
+ return a.urem(b) == 0;
+ });
+ }
- // Ensure that one divisor is a multiple of the other. If not, fail the
- // transformation.
- if (prevConstValue.urem(currConstValue) != 0 &&
- currConstValue.urem(prevConstValue) != 0)
+ if (!isApplicable)
return failure();
// The transformation is safe. Replace the existing UMod operation with a
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0fd6c18a6c241..52c915bfebc66 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -967,17 +967,17 @@ func.func @umod_fold(%arg0: i32) -> (i32, i32) {
return %0, %1: i32, i32
}
-// CHECK-LABEL: @umod_fail_vector_fold
+// CHECK-LABEL: @umod_vector_fold
// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>)
-func.func @umod_fail_vector_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
+func.func @umod_vector_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
// CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32>
// CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32>
%const1 = spirv.Constant dense<32> : vector<4xi32>
%0 = spirv.UMod %arg0, %const1 : vector<4xi32>
- // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
%const2 = spirv.Constant dense<4> : vector<4xi32>
%1 = spirv.UMod %0, %const2 : vector<4xi32>
- // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST4]]
+ // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
+ // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
// CHECK: return %[[UMOD0]], %[[UMOD1]]
return %0, %1: vector<4xi32>, vector<4xi32>
}
@@ -996,9 +996,9 @@ func.func @umod_fold_same_divisor(%arg0: i32) -> (i32, i32) {
return %0, %1: i32, i32
}
-// CHECK-LABEL: @umod_fail_fold
+// CHECK-LABEL: @umod_fail_1_fold
// CHECK-SAME: (%[[ARG:.*]]: i32)
-func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) {
+func.func @umod_fail_1_fold(%arg0: i32) -> (i32, i32) {
// CHECK: %[[CONST5:.*]] = spirv.Constant 5
// CHECK: %[[CONST32:.*]] = spirv.Constant 32
%const1 = spirv.Constant 32 : i32
@@ -1011,6 +1011,21 @@ func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) {
return %0, %1: i32, i32
}
+// CHECK-LABEL: @umod_fail_2_fold
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>)
+func.func @umod_fail_2_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
+ // CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32>
+ // CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32>
+ %const1 = spirv.Constant dense<4> : vector<4xi32>
+ %0 = spirv.UMod %arg0, %const1 : vector<4xi32>
+ // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
+ %const2 = spirv.Constant dense<32> : vector<4xi32>
+ %1 = spirv.UMod %0, %const2 : vector<4xi32>
+ // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST32]]
+ // CHECK: return %[[UMOD0]], %[[UMOD1]]
+ return %0, %1: vector<4xi32>, vector<4xi32>
+}
+
// -----
//===----------------------------------------------------------------------===//
More information about the Mlir-commits mailing list