[Mlir-commits] [mlir] [mlir][affine]fix create affine.for bug. (PR #117721)

lonely eagle llvmlistbot at llvm.org
Tue Nov 26 07:13:41 PST 2024


https://github.com/linuxlonelyeagle created https://github.com/llvm/llvm-project/pull/117721

I encountered this error in a pass I wrote:
```
matmul.mlir:5:3: error: 'affine.for' op operand cannot be used as a symbol
  linalg.matmul
  ^
matmul.mlir:5:3: note: see current operation: 
"affine.for"(%9, %8) <{lowerBoundMap = affine_map<()[s0] -> (s0)>, operandSegmentSizes = array<i32: 1, 1, 0>, step = 32 : index, upperBoundMap = affine_map<()[s0] -> (s0)>}> ({
^bb0(%arg3: index):
  "affine.yield"() : () -> ()
}) : (index, index) -> ()
make: *** [makefile:56: gemm-opt-matmul-lower] Error 1
```
This happens when an affine map is used as the lower or upper bound of a newly created affine.for and the map's symbol operand comes from a memref.dim whose memref is a function argument: the affine.for verifier then rejects that operand with "operand cannot be used as a symbol".
The IR below only illustrates the shape of the problem and is not meaningful on its own: an affine.for like this one, when created inside my pass, triggers the error above. Note, however, that the same IR written by hand verifies without problems, so I did not add a test.
```
#map = affine_map<()[s0] -> (s0)>

func.func @func(%A : memref<32x128xf32>) {
  %0 = arith.constant 0 : index
  %1 = arith.constant 1 : index
  %dim_0 = memref.dim %A, %0 : memref<32x128xf32>
  %dim_1 = memref.dim %A, %1 : memref<32x128xf32>
  affine.for %it = #map()[%dim_0] to #map()[%dim_1] {

  }
  return
}
```
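To make the change easier to review, here is a minimal C++ sketch of how the patched lines treat the shaped operand of a dim op. It is a hypothetical, self-contained model, not MLIR code: ShapedOperand, Outcome, beforePatch and afterPatch are invented names, and parentIsFunctionLike stands in for the FunctionOpInterface check in the patch. It models only the lines touched by the two hunks; the checks that run before them and the code between them are left out.

```cpp
#include <cassert>

// Hypothetical stand-in for the shaped operand of a dim op (e.g. memref.dim).
// This is NOT the MLIR API; it records only what the two patched hunks inspect.
struct ShapedOperand {
  bool isBlockArgument;      // true: block argument, false: result of an op
  bool parentIsFunctionLike; // block arguments only: parent op is a function
};

// Outcome of the block-argument handling in isDimOpValidSymbol.
enum class Outcome {
  Reject,     // "not a valid symbol"
  Accept,     // "valid symbol"
  LaterChecks // decided by the view/subview checks, not modeled here
};

// Before the patch: every block argument that reaches this point is rejected.
Outcome beforePatch(const ShapedOperand &v) {
  if (v.isBlockArgument)
    return Outcome::Reject;
  return Outcome::LaterChecks;
}

// After the patch (ignoring the code between the two hunks).
Outcome afterPatch(const ShapedOperand &v) {
  assert((v.isBlockArgument || !v.parentIsFunctionLike) &&
         "parentIsFunctionLike only applies to block arguments");
  // Hunk 1: block arguments are still rejected unless the parent op is a
  // function, so e.g. scf.for iterArgs keep the conservative behavior.
  if (v.isBlockArgument && !v.parentIsFunctionLike)
    return Outcome::Reject;
  // Hunk 2: a block argument has no defining op (getDefiningOp() is null),
  // so it is accepted before the memref.cast-skipping loop runs.
  if (v.isBlockArgument)
    return Outcome::Accept;
  return Outcome::LaterChecks;
}
```

For a function argument, beforePatch returns Reject while afterPatch returns Accept; a block argument whose parent is not a function is rejected in both versions. The second hunk is needed because, once function arguments get past the first check, getDefiningOp() returns null for them, and the memref.cast loop would otherwise hand that null op to llvm::dyn_cast, which does not accept null pointers.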



From c069ddd70519fb8016d6b9d60ed80860552c721c Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Tue, 26 Nov 2024 23:02:13 +0800
Subject: [PATCH] fix create affine.for bug.

---
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 1c5466730a5589..0d24e434328419 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -352,9 +353,13 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
 
   // Conservatively handle remaining BlockArguments as non-valid symbols.
   // E.g. scf.for iterArgs.
-  if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
-    return false;
-
+  if (auto blockArgument =
+          llvm::dyn_cast<BlockArgument>(dimOp.getShapedValue())) {
+    if (!llvm::isa<FunctionOpInterface>(
+            blockArgument.getParentRegion()->getParentOp())) {
+      return false;
+    }
+  }
   // The dim op is also okay if its operand memref is a view/subview whose
   // corresponding size is a valid symbol.
   std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
@@ -365,6 +370,11 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
 
   // Skip over all memref.cast ops (if any).
   Operation *op = dimOp.getShapedValue().getDefiningOp();
+
+  // The shaped value of the dim is a function block argument (no defining op).
+  if (!op)
+    return true;
+
   while (auto castOp = dyn_cast<memref::CastOp>(op)) {
     // Bail on unranked memrefs.
     if (isa<UnrankedMemRefType>(castOp.getSource().getType()))