[Mlir-commits] [mlir] [mlir][Affine] take strides into account for contiguity check (PR #126579)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 10 11:08:35 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: None (gdehame)

<details>
<summary>Changes</summary>

The `isContiguousAccess` utility function in LoopAnalysis.cpp did not check access strides.
This patch adds a utility function that walks an affine expression to check whether a given value acts as an offset in that expression, i.e. whether the expression has the shape `x + ...`.
The new function is used to check whether an affine access has the shape `IV + ...`, so that the contiguity check now takes strides into account.
The access-analysis unit tests are updated accordingly.

---
Full diff: https://github.com/llvm/llvm-project/pull/126579.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp (+65-2) 
- (modified) mlir/test/Dialect/Affine/access-analysis.mlir (-7) 


``````````diff
diff --git a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
index 411b5efb36cab91..73713a6ac75d462 100644
--- a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
@@ -18,6 +18,9 @@
 #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Visitors.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/MathExtras.h"
 
 #include "llvm/ADT/DenseSet.h"
@@ -190,7 +193,57 @@ DenseSet<Value> mlir::affine::getInvariantAccesses(Value iv,
   return res;
 }
 
-// TODO: check access stride.
+/// Returns true if `offset` appears as a unit-stride offset in `resultExpr`,
+/// i.e. if `resultExpr` has the shape `offset + ...` or `(offset + ...) mod m`.
+/// `offset` must be reachable from the root only through the operands of `+`
+/// and the left-hand side of `mod`; any `*`, `floordiv` or `ceildiv` on that
+/// path changes the stride, so the result is false.
+///
+/// `operands` holds the values bound to the dims of the expression, followed
+/// by the values bound to its symbols: symbol `sN` is `operands[numDims + N]`.
+///
+/// Examples, with `offset` bound to `d0`:
+///   d0, d1 * 4 + d0, (d0 + 3) mod 8             -> true
+///   d0 * 2, (d0 + 1) * 2, (d1 + d0) floordiv 4  -> false
+///
+/// NOTE(review): other occurrences of `offset` are not inspected, so e.g.
+/// `d0 + d0 floordiv 2` is still accepted.
+static bool isOffset(AffineExpr resultExpr, int numDims,
+                     ArrayRef<Value> operands, Value offset) {
+  // Leaf case: a single dim or symbol is the offset only if the operand bound
+  // to it is `offset`.
+  if (auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr))
+    return operands[dimExpr.getPosition()] == offset;
+  if (auto symExpr = dyn_cast<AffineSymbolExpr>(resultExpr))
+    return operands[symExpr.getPosition() + numDims] == offset;
+
+  // The remaining non-binary expressions are constants, which never contain
+  // the offset.
+  auto binExpr = dyn_cast<AffineBinaryOpExpr>(resultExpr);
+  if (!binExpr)
+    return false;
+
+  // Recurse explicitly rather than using AffineExpr::walk: walk() visits
+  // subexpressions in post-order, so returning WalkResult::skip() from a
+  // parent cannot stop its children from being inspected; with walk(),
+  // `(d0 + 1) * 2` would wrongly count `d0` as an offset.
+  switch (binExpr.getKind()) {
+  case AffineExprKind::Add:
+    // offset + ...: the offset may be either summand or sit in a nested sum.
+    return isOffset(binExpr.getLHS(), numDims, operands, offset) ||
+           isOffset(binExpr.getRHS(), numDims, operands, offset);
+  case AffineExprKind::Mod:
+    // (offset + ...) mod m: accesses are contiguous in chunks of m elements,
+    // so they can be considered contiguous for vectorization when the
+    // vectorization factor divides the modulus m (the right-hand side).
+    return isOffset(binExpr.getLHS(), numDims, operands, offset);
+  default:
+    // Mul, FloorDiv and CeilDiv change the stride of their operands, so an
+    // occurrence of `offset` beneath them is not a unit-stride offset.
+    return false;
+  }
+}
+
 template <typename LoadOrStoreOp>
 bool mlir::affine::isContiguousAccess(Value iv, LoadOrStoreOp memoryOp,
                                       int *memRefDim) {
@@ -219,7 +272,17 @@ bool mlir::affine::isContiguousAccess(Value iv, LoadOrStoreOp memoryOp,
     });
     // Check access invariance of each operand in 'exprOperands'.
     for (Value exprOperand : exprOperands) {
-      if (!isAccessIndexInvariant(iv, exprOperand)) {
+      // Verify that the access is contiguous along the induction variable if it depends on it
+      // by checking that at most one of the op's access map's result is of the shape IV + constant
+      auto map = AffineMap::getMultiDimIdentityMap(/*numDims=*/1, iv.getContext());
+      SmallVector<Value> operands = {exprOperand};
+      AffineValueMap avm(map, operands);
+      avm.composeSimplifyAndCanonicalize();
+      if (avm.isFunctionOf(0, iv)) {
+        if (!isOffset(resultExpr, numDims, mapOperands, exprOperand) || 
+            !isOffset(avm.getResult(0), avm.getNumDims(), avm.getOperands(), iv)) {
+          return false;
+        }
         if (uniqueVaryingIndexAlongIv != -1) {
           // 2+ varying indices -> do not vectorize along iv.
           return false;
diff --git a/mlir/test/Dialect/Affine/access-analysis.mlir b/mlir/test/Dialect/Affine/access-analysis.mlir
index 789de646a8f9e2a..4f3532d3549c0f4 100644
--- a/mlir/test/Dialect/Affine/access-analysis.mlir
+++ b/mlir/test/Dialect/Affine/access-analysis.mlir
@@ -11,17 +11,12 @@ func.func @loop_simple(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
        // expected-remark at above {{invariant along loop 1}}
        affine.load %A[%c0, 8 * %i + %j] : memref<?x?xf32>
        // expected-remark at above {{contiguous along loop 1}}
-       // Note/FIXME: access stride isn't being checked.
-       // expected-remark at -3 {{contiguous along loop 0}}
 
        // These are all non-contiguous along both loops. Nothing is emitted.
        affine.load %A[%i, %c0] : memref<?x?xf32>
        // expected-remark at above {{invariant along loop 1}}
-       // Note/FIXME: access stride isn't being checked.
        affine.load %A[%i, 8 * %j] : memref<?x?xf32>
-       // expected-remark at above {{contiguous along loop 1}}
        affine.load %A[%j, 4 * %i] : memref<?x?xf32>
-       // expected-remark at above {{contiguous along loop 0}}
      }
    }
    return
@@ -70,7 +65,6 @@ func.func @tiled(%arg0: memref<*xf32>) {
             // expected-remark at above {{invariant along loop 4}}
             affine.store %0, %alloc_0[0, %arg1 * -16 + %arg4, 0, %arg3 * -16 + %arg5] : memref<1x16x1x16xf32>
             // expected-remark at above {{contiguous along loop 4}}
-            // expected-remark at above {{contiguous along loop 2}}
             // expected-remark at above {{invariant along loop 1}}
           }
         }
@@ -79,7 +73,6 @@ func.func @tiled(%arg0: memref<*xf32>) {
             affine.for %arg6 = #map(%arg3) to #map1(%arg3) {
               %0 = affine.load %alloc_0[0, %arg1 * -16 + %arg4, -%arg2 + %arg5, %arg3 * -16 + %arg6] : memref<1x16x1x16xf32>
               // expected-remark at above {{contiguous along loop 5}}
-              // expected-remark at above {{contiguous along loop 2}}
               affine.store %0, %alloc[0, %arg5, %arg6, %arg4] : memref<1x224x224x64xf32>
               // expected-remark at above {{contiguous along loop 3}}
               // expected-remark at above {{invariant along loop 0}}

``````````

</details>


https://github.com/llvm/llvm-project/pull/126579


More information about the Mlir-commits mailing list