[Mlir-commits] [mlir] [mlir][affine]fix create affine.for bug. (PR #117721)

llvmlistbot at llvm.org
Tue Nov 26 07:14:24 PST 2024


llvmbot wrote:


@llvm/pr-subscribers-mlir-affine

@llvm/pr-subscribers-mlir

Author: lonely eagle (linuxlonelyeagle)

<details>
<summary>Changes</summary>

I encountered this error in a pass I wrote:
```
matmul.mlir:5:3: error: 'affine.for' op operand cannot be used as a symbol
  linalg.matmul
  ^
matmul.mlir:5:3: note: see current operation: 
"affine.for"(%9, %8) <{lowerBoundMap = affine_map<()[s0] -> (s0)>, operandSegmentSizes = array<i32: 1, 1, 0>, step = 32 : index, upperBoundMap = affine_map<()[s0] -> (s0)>}> ({
^bb0(%arg3: index):
  "affine.yield"() : () -> ()
}) : (index, index) -> ()
make: *** [makefile:56: gemm-opt-matmul-lower] Error 1
```
This happens when the pass creates an affine.for whose lower or upper bound is an affine map, and the symbol operand of that map comes from a memref.dim whose memref is a function argument. The affine.for verifier then rejects that operand as an invalid symbol.
The IR below shows the pattern, although the code itself is not meaningful. An affine.for of this form, when created by my pass, triggers the error above. Notably, the same IR written by hand verifies without problems, which is why I did not add a test.
```
#map = affine_map<()[s0] -> (s0)>

func.func @func(%A : memref<32x128xf32>) {
  %0 = arith.constant 0 : index
  %1 = arith.constant 1 : index
  %dim_0 = memref.dim %A, %0 : memref<32x128xf32>
  %dim_1 = memref.dim %A, %1 : memref<32x128xf32>
  affine.for %it = #map()[%dim_0] to #map()[%dim_1] {

  }
  return
}
```



---
Full diff: https://github.com/llvm/llvm-project/pull/117721.diff


1 File Affected:

- (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+13-3) 


``````````diff
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 1c5466730a5589..0d24e434328419 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -352,9 +353,13 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
 
   // Conservatively handle remaining BlockArguments as non-valid symbols.
   // E.g. scf.for iterArgs.
-  if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
-    return false;
-
+  if (auto blockArgument =
+          llvm::dyn_cast<BlockArgument>(dimOp.getShapedValue())) {
+    if (!llvm::isa<FunctionOpInterface>(
+            blockArgument.getParentRegion()->getParentOp())) {
+      return false;
+    }
+  }
   // The dim op is also okay if its operand memref is a view/subview whose
   // corresponding size is a valid symbol.
   std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
@@ -365,6 +370,11 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
 
   // Skip over all memref.cast ops (if any).
   Operation *op = dimOp.getShapedValue().getDefiningOp();
+
+  // the ShapedValue of the dim is the function block argument.
+  if (!op)
+    return true;
+
   while (auto castOp = dyn_cast<memref::CastOp>(op)) {
     // Bail on unranked memrefs.
     if (isa<UnrankedMemRefType>(castOp.getSource().getType()))

``````````

</details>
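As a reading aid, the control flow that the patch changes in isDimOpValidSymbol can be modeled in standalone C++. This is a minimal sketch, not MLIR code: ShapedValueKind, ModelDimOp, laterChecksPass, isDimOpValidSymbolBefore and isDimOpValidSymbolAfter are hypothetical names, and the model covers only the lines visible in the two hunks. It assumes that the lines hidden between the hunks bail out when the dimension index is not a constant (which the std::optional index variable suggests); checks earlier in the function, and the view/subview/cast analysis after the hunks, are not modeled.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical classification of the value a dim op reads its shape from.
enum class ShapedValueKind {
  FunctionArgument,   // block argument whose region belongs to a function op
  OtherBlockArgument, // any other block argument, e.g. an scf.for iter_arg
  OpResult            // result of some defining operation
};

// Hypothetical stand-in for a ShapedDimOpInterface op.
struct ModelDimOp {
  ShapedValueKind source;
  // What getConstantIntValue(dimOp.getDimension()) would return (assumed).
  std::optional<std::int64_t> constantIndex;
  // Assumed outcome of the later view/subview/cast checks (not modeled).
  bool laterChecksPass;
};

// Before the patch: every block argument is rejected outright.
bool isDimOpValidSymbolBefore(const ModelDimOp &dim) {
  if (dim.source != ShapedValueKind::OpResult)
    return false;
  if (!dim.constantIndex.has_value()) // assumed hidden bail-out
    return false;
  return dim.laterChecksPass;
}

// After the patch: only non-function block arguments are rejected early.
bool isDimOpValidSymbolAfter(const ModelDimOp &dim) {
  if (dim.source == ShapedValueKind::OtherBlockArgument)
    return false;
  if (!dim.constantIndex.has_value()) // assumed hidden bail-out
    return false;
  // getDefiningOp() is null only for block arguments, and the only block
  // arguments that reach this point are function arguments.
  if (dim.source == ShapedValueKind::FunctionArgument)
    return true;
  return dim.laterChecksPass;
}
```

Under this model, a function-argument memref is accepted only when the dim index is a constant, because the constant-index bail-out still runs before the new `if (!op) return true;` check.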


https://github.com/llvm/llvm-project/pull/117721

