[Mlir-commits] [mlir] [MLIR][Affine] Add functionality to demote invalid symbols to dims (PR #128289)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 21 22:28:18 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-affine

Author: Uday Bondhugula (bondhugula)

<details>
<summary>Changes</summary>

Fixes: https://github.com/llvm/llvm-project/issues/120189

Introduce a method to demote a symbolic operand to a dimensional one
(the inverse of the current canonicalizePromotedSymbols).
Demote operands that could/should have been valid affine dimensional
values (affine loop IVs or their functions) from symbols to dims. This
is a general method that can be used to legalize a map + operands post
construction depending on its operands.

In several cases, affine.apply operands that could be dims for the
purpose of affine analysis were being used in symbolic positions
undetected. Fix the verifier so that affine dim-only SSA values aren't
used in symbolic positions; otherwise, it was possible for
`-canonicalize` to have generated invalid IR. This doesn't affect users
in other contexts where the operands were neither valid dims nor
symbols (e.g., in scf.for or other region ops).

In some cases, this change also leads to better-simplified operands, with
duplicates eliminated, as shown in one of the test cases where the same
operand appeared both as a symbol and as a dim.

This PR also fixes test cases where dimensional positions should ideally
have been used with affine.apply (for affine loop IVs, for example).
For several use cases, the IR builder/user wouldn't need to worry about
dim/symbols since canonicalizeMapOrSetAndOperands is updated to
transparently take care of it to ensure valid IR. Users outside of
affine analyses/dialects remain unaffected.


---
Full diff: https://github.com/llvm/llvm-project/pull/128289.diff


5 Files Affected:

- (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+61-2) 
- (modified) mlir/test/Dialect/Affine/canonicalize.mlir (+2-2) 
- (modified) mlir/test/Dialect/Affine/invalid.mlir (+14) 
- (modified) mlir/test/Dialect/Affine/loop-fusion-4.mlir (+4-4) 
- (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+18-18) 


``````````diff
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 06493de2b09d4..609fb9ab6a3e7 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -568,6 +568,15 @@ LogicalResult AffineApplyOp::verify() {
   if (affineMap.getNumResults() != 1)
     return emitOpError("mapping must produce one value");
 
+  // Do not allow valid dims to be used in symbol positions. We do allow
+  // affine.apply to use operands for values that may neither qualify as affine
+  // dims or affine symbols due to usage outside of affine ops, analyses, etc.
+  Region *region = getAffineScope(*this);
+  for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {
+    if (::isValidDim(operand, region) && !::isValidSymbol(operand, region))
+      return emitError("dimensional operand cannot be used as a symbol");
+  }
+
   return success();
 }
 
@@ -1351,13 +1360,62 @@ static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
 
   resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
   *operands = resultOperands;
-  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
-                                              oldNumSyms + nextSym);
+  *mapOrSet = mapOrSet->replaceDimsAndSymbols(
+      dimRemapping, /*symReplacements=*/{}, nextDim, oldNumSyms + nextSym);
 
   assert(mapOrSet->getNumInputs() == operands->size() &&
          "map/set inputs must match number of operands");
 }
 
+// A valid affine dimension may appear as a symbol in affine.apply operations.
+// This function canonicalizes symbols that are valid dims, but not valid
+// symbols into actual dims. Without such a legalization, the affine.apply will
+// be invalid. This method is the exact inverse of canonicalizePromotedSymbols.
+template <class MapOrSet>
+static void legalizeDemotedDims(MapOrSet *mapOrSet,
+                                SmallVectorImpl<Value> &operands) {
+  if (!mapOrSet || operands.empty())
+    return;
+
+  assert(mapOrSet->getNumInputs() == operands.size() &&
+         "map/set inputs must match number of operands");
+
+  auto *context = mapOrSet->getContext();
+  SmallVector<Value, 8> resultOperands;
+  resultOperands.reserve(operands.size());
+  SmallVector<Value, 8> remappedDims;
+  remappedDims.reserve(operands.size());
+  unsigned nextSym = 0;
+  unsigned nextDim = 0;
+  unsigned oldNumDims = mapOrSet->getNumDims();
+  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
+  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
+    if (i >= oldNumDims) {
+      if (operands[i] && isValidDim(operands[i]) &&
+          !isValidSymbol(operands[i])) {
+        // This is a valid dim that appears as a symbol, legalize it.
+        symRemapping[i - oldNumDims] =
+            getAffineDimExpr(oldNumDims + nextDim++, context);
+        remappedDims.push_back(operands[i]);
+      } else {
+        symRemapping[i - oldNumDims] = getAffineSymbolExpr(nextSym++, context);
+        resultOperands.push_back(operands[i]);
+      }
+    } else {
+      resultOperands.push_back(operands[i]);
+    }
+  }
+
+  resultOperands.insert(resultOperands.begin() + oldNumDims,
+                        remappedDims.begin(), remappedDims.end());
+  operands = resultOperands;
+  *mapOrSet = mapOrSet->replaceDimsAndSymbols(
+      /*dimReplacements=*/{}, symRemapping, oldNumDims + nextDim, nextSym);
+
+  assert(mapOrSet->getNumInputs() == operands.size() &&
+         "map/set inputs must match number of operands");
+}
+
 // Works for either an affine map or an integer set.
 template <class MapOrSet>
 static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
@@ -1372,6 +1430,7 @@ static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
          "map/set inputs must match number of operands");
 
   canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
+  legalizeDemotedDims<MapOrSet>(mapOrSet, *operands);
 
   // Check to see what dims are used.
   llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index a9ac13ad71624..3a3114f8da63a 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1460,8 +1460,8 @@ func.func @mod_of_mod(%lb: index, %ub: index, %step: index) -> (index, index) {
 func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () {
   // CHECK: affine.for [[I_0_:%.+]] = 0 to 8 {
   affine.for %arg3 = 0 to 8  {
-    %1 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg3]
-    // CHECK: affine.prefetch [[PARAM_0_]][symbol([[I_0_]]) * 64], read, locality<3>, data : memref<512xf32>
+    %1 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
+    // CHECK: affine.prefetch [[PARAM_0_]][[[I_0_]] * 64], read, locality<3>, data : memref<512xf32>
     affine.prefetch %arg0[%1], read, locality<3>, data : memref<512xf32>
   }
   return
diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
index 9bbd19c381163..9703c05fff8f6 100644
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -563,3 +563,17 @@ func.func @no_upper_bound() {
   }
   return
 }
+
+// -----
+
+func.func @invalid_symbol() {
+  affine.for %arg1 = 0 to 1 {
+    affine.for %arg2 = 0 to 26 {
+      affine.for %arg3 = 0 to 23 {
+        affine.apply affine_map<()[s0, s1] -> (s0 * 23 + s1)>()[%arg1, %arg3]
+        // expected-error at above {{dimensional operand cannot be used as a symbol}}
+      }
+    }
+  }
+  return
+}
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index 9cc2dc1520623..a2a20b05a340f 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -360,8 +360,8 @@ func.func @same_memref_load_multiple_stores(%producer : memref<32xf32>, %produce
 
 // -----
 
-#map = affine_map<()[s0] -> (s0 + 5)>
-#map1 = affine_map<()[s0] -> (s0 + 17)>
+#map = affine_map<(d0) -> (d0 + 5)>
+#map1 = affine_map<(d0) -> (d0 + 17)>
 
 // Test with non-int/float memref types.
 
@@ -382,8 +382,8 @@ func.func @memref_index_type() {
   }
   affine.for %arg3 = 0 to 3 {
     %4 = affine.load %alloc_2[%arg3] : memref<3xindex>
-    %5 = affine.apply #map()[%4]
-    %6 = affine.apply #map1()[%3]
+    %5 = affine.apply #map(%4)
+    %6 = affine.apply #map1(%3)
     %7 = memref.load %alloc[%5, %6] : memref<8x18xf32>
     affine.store %7, %alloc_1[%arg3] : memref<3xf32>
   }
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 327cacf7d9a20..62f58d5585de4 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -496,8 +496,8 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
 // CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
 // CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
 func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
@@ -518,16 +518,16 @@ func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc:
 // CHECK-NEXT:   %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
 // CHECK-NEXT:   affine.for %[[ARG4:.*]] = 0 to %[[DIM]] step 64 {
 // CHECK-NEXT:   affine.for %[[ARG5:.*]] = 0 to 16 step 16 {
-// CHECK-NEXT:   %[[VAL0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
-// CHECK-NEXT:   %[[VAL1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]]]
+// CHECK-NEXT:   %[[VAL0:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])[%[[ARG2]]]
+// CHECK-NEXT:   %[[VAL1:.*]] = affine.apply #[[$MAP1]](%[[ARG5]])
 // CHECK-NEXT:   %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL0]], %[[VAL1]]] : memref<2048x16xf32>
-// CHECK-NEXT:   %[[VAL3:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK-NEXT:   %[[VAL3:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])[%[[ARG2]]]
 // CHECK-NEXT:   affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG5]]] : memref<2048x16xf32>
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 * 1024 + s1)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -549,14 +549,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0:
 // CHECK-NEXT:  affine.for %[[ARG4:.*]] = 0 to 1024 {
 // CHECK-NEXT:   affine.for %[[ARG5:.*]] = 0 to 1020 {
 // CHECK-NEXT:    affine.for %[[ARG6:.*]] = 0 to 1 {
-// CHECK-NEXT:     %[[IDX1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]], %[[ARG4]]]
-// CHECK-NEXT:     %[[IDX2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
+// CHECK-NEXT:     %[[IDX1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
+// CHECK-NEXT:     %[[IDX2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
 // CHECK-NEXT:     affine.load %[[ARG0]][%[[IDX1]], %[[IDX2]]] : memref<1024x1024xf32>
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 + s0 * 1024)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1025 + d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -578,14 +578,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_a
 // CHECK-NEXT:  affine.for %[[ARG4:.*]] = 0 to 1024 {
 // CHECK-NEXT:   affine.for %[[ARG5:.*]] = 0 to 1020 {
 // CHECK-NEXT:    affine.for %[[ARG6:.*]] = 0 to 1 {
-// CHECK-NEXT:      %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])[%[[ARG3]]]
-// CHECK-NEXT:      %[[TMP3:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
+// CHECK-NEXT:      %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
+// CHECK-NEXT:      %[[TMP3:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
 // CHECK-NEXT:      affine.load %[[ARG0]][%[[TMP1]], %[[TMP3]]] : memref<1024x1024xf32>
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 * 1024)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 * 1024)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -608,8 +608,8 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_c
 // CHECK-NEXT:   affine.for %[[ARG4:.*]] = 0 to 1024 {
 // CHECK-NEXT:    affine.for %[[ARG5:.*]] = 0 to 1020 {
 // CHECK-NEXT:     affine.for %[[ARG6:.*]] = 0 to 1 {
-// CHECK-NEXT:      %[[TMP1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]]]
-// CHECK-NEXT:      %[[TMP2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
+// CHECK-NEXT:      %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK-NEXT:      %[[TMP2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
 // CHECK-NEXT:      memref.load %[[ARG0]][%[[TMP1]], %[[TMP2]]] : memref<1024x1024xf32>
 
 // -----
@@ -678,7 +678,7 @@ func.func @fold_load_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index,
 // -----
 
 // CHECK-LABEL: func @fold_store_keep_nontemporal(
-//      CHECK:   memref.store %{{.+}}, %{{.+}}[%{{.+}}, %{{.+}}]  {nontemporal = true} : memref<12x32xf32> 
+//      CHECK:   memref.store %{{.+}}, %{{.+}}[%{{.+}}, %{{.+}}]  {nontemporal = true} : memref<12x32xf32>
 func.func @fold_store_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32) {
   %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] :
     memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>

``````````

</details>


https://github.com/llvm/llvm-project/pull/128289


More information about the Mlir-commits mailing list