[Mlir-commits] [mlir] 6d3ebd8 - [mlir][affine] Allow `memref.cast` in `isDimOpValidSymbol` (#74401)

llvmlistbot at llvm.org
Wed Dec 13 15:54:43 PST 2023


Author: Matthias Springer
Date: 2023-12-14T08:54:39+09:00
New Revision: 6d3ebd831c31d473acb18511949d04038115864a

URL: https://github.com/llvm/llvm-project/commit/6d3ebd831c31d473acb18511949d04038115864a
DIFF: https://github.com/llvm/llvm-project/commit/6d3ebd831c31d473acb18511949d04038115864a.diff

LOG: [mlir][affine] Allow `memref.cast` in `isDimOpValidSymbol` (#74401)

`isDimOpValidSymbol` is used during the verification of `affine.for`
ops to check whether lower/upper bound (LB/UB) values are valid affine
symbols. This change adds support for `memref.cast`: a cast can be
skipped over if it is a ranked -> ranked cast.

This change fixes `mlir/test/Transforms/canonicalize.mlir`, which used
to fail when the IR was verified after each pattern application
(#74270). In this test case, a pattern that folds dynamic
offsets/sizes/strides to static ones is applied. This pattern inserts a
trivial `memref.cast` that can be folded away. However, that folding
happens only after the pattern application. Because `isDimOpValidSymbol`
did not look through `memref.cast`, the IR failed to verify right after
the offsets/sizes/strides canonicalization pattern was applied.

Note: The `affine.for` verifier violates MLIR guidelines, which say that
only local properties of an op should be verified; a verifier should not
inspect the defining ops of its operands. (Following that guideline
strictly would mean that constraints such as "operand is a valid affine
symbol" cannot be verified.)

Added: 
    

Modified: 
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 7f2f3c3410c33b..d5be2e906989fa 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -354,8 +354,19 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
   if (!index.has_value())
     return false;
 
+  // Skip over all memref.cast ops (if any).
+  Operation *op = dimOp.getShapedValue().getDefiningOp();
+  while (auto castOp = dyn_cast<memref::CastOp>(op)) {
+    // Bail on unranked memrefs.
+    if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
+      return false;
+    op = castOp.getSource().getDefiningOp();
+    if (!op)
+      return false;
+  }
+
   int64_t i = index.value();
-  return TypeSwitch<Operation *, bool>(dimOp.getShapedValue().getDefiningOp())
+  return TypeSwitch<Operation *, bool>(op)
       .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
           [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
       .Default([](Operation *) { return false; });

