[Mlir-commits] [mlir] 7c2d315 - [mlir][spirv] Don't return value when cannot fold spirv.bitcast

Lei Zhang llvmlistbot at llvm.org
Mon Nov 7 16:13:58 PST 2022


Author: Lei Zhang
Date: 2022-11-07T19:11:46-05:00
New Revision: 7c2d3153a9481793da58894dbf35d4994f3b67a4

URL: https://github.com/llvm/llvm-project/commit/7c2d3153a9481793da58894dbf35d4994f3b67a4
DIFF: https://github.com/llvm/llvm-project/commit/7c2d3153a9481793da58894dbf35d4994f3b67a4.diff

LOG: [mlir][spirv] Don't return value when cannot fold spirv.bitcast

Returning a value (here, the op's own result) would make the canonicalization
infrastructure think that folding succeeded, so the op would be folded again on
every iteration when invoked via, e.g., `applyPatternsAndFoldGreedily`, and the
driver would eventually fail because it does not converge within the default
limit of 10 iterations.
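To illustrate the failure mode, here is a minimal C++ sketch of a greedy fixpoint loop over a single op. Everything in it is hypothetical: `FoldOutcome` and `runGreedy` are illustrative stand-ins, not MLIR's `OpFoldResult` or its greedy rewrite driver. It only assumes what the message states: a fold that reports success counts as a change, and the loop gives up after 10 iterations by default.

```cpp
#include <cassert>
#include <functional>

// Hypothetical model of what a fold hook can report (not the MLIR API).
enum class FoldOutcome {
  Failure,   // like `return {};`: nothing changed
  InPlace,   // like `return getResult();`: op counts as modified in place
  Replaced,  // like returning another Value: op is replaced and erased
};

// Hypothetical, heavily simplified greedy driver for one op. It keeps
// sweeping while the previous sweep reported a change, up to maxIterations
// (10 mirrors the default described above). Returns true on convergence.
bool runGreedy(const std::function<FoldOutcome()> &fold,
               int maxIterations = 10) {
  assert(maxIterations > 0 && "need at least one iteration");
  bool opAlive = true;  // becomes false once the op has been replaced
  bool changed = true;
  int iteration = 0;
  while (changed && iteration < maxIterations) {
    ++iteration;
    changed = false;
    if (opAlive) {
      FoldOutcome outcome = fold();
      if (outcome == FoldOutcome::Replaced)
        opAlive = false;  // replaced ops are erased and never revisited
      changed = outcome != FoldOutcome::Failure;
    }
  }
  return !changed;  // converged only if the last sweep changed nothing
}
```

A fold hook that always reports `InPlace` (as the old `return getResult();` did) makes `runGreedy` run all 10 iterations and return `false`, whereas one that reports `Failure` lets it converge after the first sweep.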

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D137598

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 57e6475548642..b068d23f0e9f0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -117,22 +117,22 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
 //===----------------------------------------------------------------------===//
 
 OpFoldResult spirv::BitcastOp::fold(ArrayRef<Attribute> /*operands*/) {
-  Value arg = getOperand();
-  if (getType() == arg.getType())
-    return arg;
+  Value curInput = getOperand();
+  if (getType() == curInput.getType())
+    return curInput;
 
   // Look through nested bitcasts.
-  if (auto bitcast = arg.getDefiningOp<spirv::BitcastOp>()) {
-    Value nestedArg = bitcast.getOperand();
-    if (nestedArg.getType() == getType())
-      return nestedArg;
+  if (auto prevCast = curInput.getDefiningOp<spirv::BitcastOp>()) {
+    Value prevInput = prevCast.getOperand();
+    if (prevInput.getType() == getType())
+      return prevInput;
 
-    getOperandMutable().assign(nestedArg);
+    getOperandMutable().assign(prevInput);
     return getResult();
   }
 
   // TODO(kuhar): Consider constant-folding the operand attribute.
-  return getResult();
+  return {};
 }
 
 //===----------------------------------------------------------------------===//
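The post-fix control flow of `spirv::BitcastOp::fold` can be modelled outside MLIR roughly as follows. This is a sketch under stated assumptions: `Value`, `BitcastOp`, `FoldKind`, `FoldResult` and `foldBitcast` are hypothetical stand-ins with types reduced to strings, and `FoldKind::Failure` plays the role of the new `return {};`.

```cpp
#include <string>

// Hypothetical stand-ins for MLIR values and ops; not the real API.
struct BitcastOp;
struct Value {
  std::string type;            // the value's type, e.g. "i32"
  BitcastOp *definingBitcast;  // non-null if produced by a bitcast
};
struct BitcastOp {
  Value operand;
  std::string resultType;
};

enum class FoldKind { Failure, InPlace, ReplaceWith };
struct FoldResult {
  FoldKind kind;
  Value replacement;  // only meaningful for FoldKind::ReplaceWith
};

// Mirrors the three outcomes of the patched fold hook.
FoldResult foldBitcast(BitcastOp &op) {
  const Value curInput = op.operand;
  // Identity cast: bitcast %x : T to T folds to %x.
  if (op.resultType == curInput.type)
    return {FoldKind::ReplaceWith, curInput};

  // Look through a producing bitcast.
  if (BitcastOp *prevCast = curInput.definingBitcast) {
    Value prevInput = prevCast->operand;
    // A -> B -> A round trip folds to the original value.
    if (prevInput.type == op.resultType)
      return {FoldKind::ReplaceWith, prevInput};
    // A -> B -> C: bypass the middle cast by rewiring the operand in place.
    op.operand = prevInput;
    return {FoldKind::InPlace, Value{}};
  }

  // Nothing to fold: report failure instead of claiming success.
  return {FoldKind::Failure, Value{}};
}
```

For a chain `i32 -> f32 -> vector<2xi16>`, the first call rewires the outer cast to read the `i32` value directly and reports an in-place change; a second call finds no producing bitcast and reports failure, so a greedy driver reaches a fixpoint instead of looping.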
