[Mlir-commits] [mlir] a3c7d46 - [mlir][spirv] Implement UMod canonicalization for vector constants (#141902)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 9 08:09:39 PDT 2025


Author: Darren Wihandi
Date: 2025-06-09T11:09:36-04:00
New Revision: a3c7d461456f2da25c1d119b6686773f675e313e

URL: https://github.com/llvm/llvm-project/commit/a3c7d461456f2da25c1d119b6686773f675e313e
DIFF: https://github.com/llvm/llvm-project/commit/a3c7d461456f2da25c1d119b6686773f675e313e.diff

LOG: [mlir][spirv] Implement UMod canonicalization for vector constants (#141902)

Closes #63174. 

Extends this transformation pattern, which was previously applied only
to scalars, to vectors:
```
%1 = "spirv.UMod"(%0, %CONST_32) : (i32, i32) -> i32
%2 = "spirv.UMod"(%1, %CONST_4) : (i32, i32) -> i32
```
to
```
%1 = "spirv.UMod"(%0, %CONST_32) : (i32, i32) -> i32
%2 = "spirv.UMod"(%0, %CONST_4) : (i32, i32) -> i32
```

Additionally fixes an issue where patterns like this:
```
%1 = "spirv.UMod"(%0, %CONST_4) : (i32, i32) -> i32
%2 = "spirv.UMod"(%1, %CONST_32) : (i32, i32) -> i32
```
were incorrectly canonicalized to:
```
%1 = "spirv.UMod"(%0, %CONST_4) : (i32, i32) -> i32
%2 = "spirv.UMod"(%0, %CONST_32) : (i32, i32) -> i32
```
which is incorrect, since `(X % A) % B == X % B` holds for all X if and
only if A is a multiple of B, i.e., B divides A. For example, with X = 10:
`(10 % 4) % 32 == 2`, but `10 % 32 == 10`.

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
    mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index e36d4b910193e..03af61c81ae6c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -326,7 +326,6 @@ void spirv::UMulExtendedOp::getCanonicalizationPatterns(
 
 // The transformation is only applied if one divisor is a multiple of the other.
 
-// TODO(https://github.com/llvm/llvm-project/issues/63174): Add support for vector constants
 struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -336,19 +335,29 @@ struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
     if (!prevUMod)
       return failure();
 
-    IntegerAttr prevValue;
-    IntegerAttr currValue;
+    TypedAttr prevValue;
+    TypedAttr currValue;
     if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) ||
         !matchPattern(umodOp.getOperand(1), m_Constant(&currValue)))
       return failure();
 
-    APInt prevConstValue = prevValue.getValue();
-    APInt currConstValue = currValue.getValue();
+    // Ensure that previous divisor is a multiple of the current divisor. If
+    // not, fail the transformation.
+    bool isApplicable = false;
+    if (auto prevInt = dyn_cast<IntegerAttr>(prevValue)) {
+      auto currInt = cast<IntegerAttr>(currValue);
+      isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0;
+    } else if (auto prevVec = dyn_cast<DenseElementsAttr>(prevValue)) {
+      auto currVec = cast<DenseElementsAttr>(currValue);
+      isApplicable = llvm::all_of(llvm::zip_equal(prevVec.getValues<APInt>(),
+                                                  currVec.getValues<APInt>()),
+                                  [](const auto &pair) {
+                                    auto &[prev, curr] = pair;
+                                    return prev.urem(curr) == 0;
+                                  });
+    }
 
-    // Ensure that one divisor is a multiple of the other. If not, fail the
-    // transformation.
-    if (prevConstValue.urem(currConstValue) != 0 &&
-        currConstValue.urem(prevConstValue) != 0)
+    if (!isApplicable)
       return failure();
 
     // The transformation is safe. Replace the existing UMod operation with a

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0fd6c18a6c241..722c27586aa61 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -967,17 +967,17 @@ func.func @umod_fold(%arg0: i32) -> (i32, i32) {
   return %0, %1: i32, i32
 }
 
-// CHECK-LABEL: @umod_fail_vector_fold
+// CHECK-LABEL: @umod_vector_fold
 // CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>)
-func.func @umod_fail_vector_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
+func.func @umod_vector_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
   // CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32>
   // CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32>
   %const1 = spirv.Constant dense<32> : vector<4xi32>
   %0 = spirv.UMod %arg0, %const1 : vector<4xi32>
-  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
   %const2 = spirv.Constant dense<4> : vector<4xi32>
   %1 = spirv.UMod %0, %const2 : vector<4xi32>
-  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST4]]
+  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
+  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
   // CHECK: return %[[UMOD0]], %[[UMOD1]]
   return %0, %1: vector<4xi32>, vector<4xi32>
 } 
@@ -996,9 +996,9 @@ func.func @umod_fold_same_divisor(%arg0: i32) -> (i32, i32) {
   return %0, %1: i32, i32
 }
 
-// CHECK-LABEL: @umod_fail_fold
+// CHECK-LABEL: @umod_fail_1_fold
 // CHECK-SAME: (%[[ARG:.*]]: i32)
-func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) {
+func.func @umod_fail_1_fold(%arg0: i32) -> (i32, i32) {
   // CHECK: %[[CONST5:.*]] = spirv.Constant 5
   // CHECK: %[[CONST32:.*]] = spirv.Constant 32
   %const1 = spirv.Constant 32 : i32
@@ -1011,6 +1011,51 @@ func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) {
   return %0, %1: i32, i32
 }
 
+// CHECK-LABEL: @umod_fail_2_fold
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @umod_fail_2_fold(%arg0: i32) -> (i32, i32) {
+  // CHECK: %[[CONST32:.*]] = spirv.Constant 32
+  // CHECK: %[[CONST4:.*]] = spirv.Constant 4
+  %const1 = spirv.Constant 4 : i32
+  %0 = spirv.UMod %arg0, %const1 : i32
+  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
+  %const2 = spirv.Constant 32 : i32
+  %1 = spirv.UMod %0, %const2 : i32
+  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST32]]
+  // CHECK: return %[[UMOD0]], %[[UMOD1]]
+  return %0, %1: i32, i32
+}
+
+// CHECK-LABEL: @umod_vector_fail_1_fold
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>)
+func.func @umod_vector_fail_1_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
+  // CHECK: %[[CONST9:.*]] = spirv.Constant dense<9> : vector<4xi32>
+  // CHECK: %[[CONST64:.*]] = spirv.Constant dense<64> : vector<4xi32>
+  %const1 = spirv.Constant dense<64> : vector<4xi32>
+  %0 = spirv.UMod %arg0, %const1 : vector<4xi32>
+  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST64]]
+  %const2 = spirv.Constant dense<9> : vector<4xi32>
+  %1 = spirv.UMod %0, %const2 : vector<4xi32>
+  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST9]]
+  // CHECK: return %[[UMOD0]], %[[UMOD1]]
+  return %0, %1: vector<4xi32>, vector<4xi32>
+}
+
+// CHECK-LABEL: @umod_vector_fail_2_fold
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>)
+func.func @umod_vector_fail_2_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
+  // CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32>
+  // CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32>
+  %const1 = spirv.Constant dense<4> : vector<4xi32>
+  %0 = spirv.UMod %arg0, %const1 : vector<4xi32>
+  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
+  %const2 = spirv.Constant dense<32> : vector<4xi32>
+  %1 = spirv.UMod %0, %const2 : vector<4xi32>
+  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST32]]
+  // CHECK: return %[[UMOD0]], %[[UMOD1]]
+  return %0, %1: vector<4xi32>, vector<4xi32>
+}
+
 // -----
 
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list