[Mlir-commits] [mlir] 0e3426f - [mlir][spirv] Add a canonicalization pattern for UModOp

Jakub Kuderski llvmlistbot at llvm.org
Thu Jun 8 07:34:14 PDT 2023


Author: Nishant Patel
Date: 2023-06-08T10:32:01-04:00
New Revision: 0e3426f3fe27223ab6a1a67e8756df2214c258e3

URL: https://github.com/llvm/llvm-project/commit/0e3426f3fe27223ab6a1a67e8756df2214c258e3
DIFF: https://github.com/llvm/llvm-project/commit/0e3426f3fe27223ab6a1a67e8756df2214c258e3.diff

LOG: [mlir][spirv] Add a canonicalization pattern for UModOp

Add a transformation for a pattern like

%6 = "spirv.Constant"() <{value = 32 : i32}> : () -> i32
%7 = "spirv.UMod"(%5, %6) : (i32, i32) -> i32
%8 = "spirv.Constant"() <{value = 4 : i32}> : () -> i32
%9 = "spirv.UMod"(%7, %8) : (i32, i32) -> i32

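into

%6 = "spirv.Constant"() <{value = 32 : i32}> : () -> i32
%7 = "spirv.UMod"(%5, %6) : (i32, i32) -> i32
%8 = "spirv.Constant"() <{value = 4 : i32}> : () -> i32
%9 = "spirv.UMod"(%5, %8) : (i32, i32) -> i32

This rewrite is only valid when the second divisor evenly divides the first
(here 4 divides 32); for example, (x umod 4) umod 32 must not become x umod 32.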

Patch By: nbpatel
Reviewed By: antiagainst, kuhar

Differential Revision: https://reviews.llvm.org/D152341

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
    mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index f18796b7c96ec..c4d1e01f9feef 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -810,6 +810,8 @@ def SPIRV_UModOp : SPIRV_ArithmeticBinaryOp<"UMod",
 
     ```
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 #endif // MLIR_DIALECT_SPIRV_IR_ARITHMETIC_OPS

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 3ada160444dd4..def62b4467ce3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -114,6 +114,58 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
   results.add<CombineChainedAccessChain>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.UMod
+//===----------------------------------------------------------------------===//
+
+// Input:
+//    %0 = spirv.UMod %arg0, %const32 : i32
+//    %1 = spirv.UMod %0, %const4 : i32
+// Output:
+//    %0 = spirv.UMod %arg0, %const32 : i32
+//    %1 = spirv.UMod %arg0, %const4 : i32
+//
+// Applied only when the current divisor evenly divides the previous one.
+// TODO(https://github.com/llvm/llvm-project/issues/63174): Add support for vector constants
+struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::UModOp umodOp,
+                                PatternRewriter &rewriter) const override {
+    auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
+    if (!prevUMod)
+      return failure();
+
+    IntegerAttr prevValue;
+    IntegerAttr currValue;
+    if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) ||
+        !matchPattern(umodOp.getOperand(1), m_Constant(&currValue)))
+      return failure();
+
+    APInt prevConstValue = prevValue.getValue();
+    APInt currConstValue = currValue.getValue();
+
+    // x umod 0 is undefined, and APInt::urem asserts on a zero divisor.
+    if (prevConstValue.isZero() || currConstValue.isZero())
+      return failure();
+
+    // Fold only if curr divides prev: (x umod 4) umod 32 != x umod 32.
+    if (prevConstValue.urem(currConstValue) != 0)
+      return failure();
+
+    // Safe: rewrite to a UMod of the original dividend by the current divisor.
+    rewriter.replaceOpWithNewOp<spirv::UModOp>(
+        umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));
+
+    return success();
+  }
+};
+
+void spirv::UModOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                MLIRContext *context) {
+  patterns.add<UModSimplification>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.BitcastOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index d5dc97a6245b1..52607d7267852 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -453,6 +453,101 @@ func.func @const_fold_vector_isub() -> vector<3xi32> {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.UMod
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @umod_fold
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @umod_fold(%arg0: i32) -> (i32, i32) {
+  // CHECK: %[[CONST4:.*]] = spirv.Constant 4
+  // CHECK: %[[CONST32:.*]] = spirv.Constant 32
+  %const1 = spirv.Constant 32 : i32
+  %0 = spirv.UMod %arg0, %const1 : i32
+  %const2 = spirv.Constant 4 : i32
+  %1 = spirv.UMod %0, %const2 : i32
+  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
+  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
+  // CHECK: return %[[UMOD0]], %[[UMOD1]]
+  return %0, %1: i32, i32
+}
+
+// CHECK-LABEL: @umod_fail_vector_fold
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>)
+func.func @umod_fail_vector_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
+  // CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32>
+  // CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32>
+  %const1 = spirv.Constant dense<32> : vector<4xi32>
+  %0 = spirv.UMod %arg0, %const1 : vector<4xi32>
+  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
+  %const2 = spirv.Constant dense<4> : vector<4xi32>
+  %1 = spirv.UMod %0, %const2 : vector<4xi32>
+  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST4]]
+  // CHECK: return %[[UMOD0]], %[[UMOD1]]
+  return %0, %1: vector<4xi32>, vector<4xi32>
+}
+
+// CHECK-LABEL: @umod_fold_same_divisor
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @umod_fold_same_divisor(%arg0: i32) -> (i32, i32) {
+  // CHECK: %[[CONST1:.*]] = spirv.Constant 32
+  %const1 = spirv.Constant 32 : i32
+  %0 = spirv.UMod %arg0, %const1 : i32
+  %const2 = spirv.Constant 32 : i32
+  %1 = spirv.UMod %0, %const2 : i32
+  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST1]]
+  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[ARG]], %[[CONST1]]
+  // CHECK: return %[[UMOD0]], %[[UMOD1]]
+  return %0, %1: i32, i32
+}
+
+// CHECK-LABEL: @umod_fail_fold
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) {
+  // CHECK: %[[CONST5:.*]] = spirv.Constant 5
+  // CHECK: %[[CONST32:.*]] = spirv.Constant 32
+  %const1 = spirv.Constant 32 : i32
+  %0 = spirv.UMod %arg0, %const1 : i32
+  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
+  %const2 = spirv.Constant 5 : i32
+  %1 = spirv.UMod %0, %const2 : i32
+  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST5]]
+  // CHECK: return %[[UMOD0]], %[[UMOD1]]
+  return %0, %1: i32, i32
+}
+
+// CHECK-LABEL: @umod_fail_fold_larger_divisor
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @umod_fail_fold_larger_divisor(%arg0: i32) -> (i32, i32) {
+  // CHECK-DAG: %[[CONST4:.*]] = spirv.Constant 4
+  // CHECK-DAG: %[[CONST32:.*]] = spirv.Constant 32
+  %const1 = spirv.Constant 4 : i32
+  %0 = spirv.UMod %arg0, %const1 : i32
+  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
+  %const2 = spirv.Constant 32 : i32
+  %1 = spirv.UMod %0, %const2 : i32
+  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST32]]
+  // CHECK: return %[[UMOD0]], %[[UMOD1]]
+  return %0, %1: i32, i32
+}
+
+// CHECK-LABEL: @umod_fail_fold_zero_divisor
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @umod_fail_fold_zero_divisor(%arg0: i32) -> (i32, i32) {
+  // CHECK-DAG: %[[CONST0:.*]] = spirv.Constant 0
+  // CHECK-DAG: %[[CONST32:.*]] = spirv.Constant 32
+  %const1 = spirv.Constant 32 : i32
+  %0 = spirv.UMod %arg0, %const1 : i32
+  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
+  %const2 = spirv.Constant 0 : i32
+  %1 = spirv.UMod %0, %const2 : i32
+  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST0]]
+  // CHECK: return %[[UMOD0]], %[[UMOD1]]
+  return %0, %1: i32, i32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.LogicalAnd
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list