[Mlir-commits] [mlir] [mlir][affine] If all operands of a Pure operation are dimensional identifiers, then its results are dimensional identifiers. (PR #123542)

lonely eagle llvmlistbot at llvm.org
Sun Jan 19 18:49:26 PST 2025


https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/123542

>From 50fcac3e325d7e19094309481161d7a8a3b3bce4 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Mon, 20 Jan 2025 10:38:35 +0800
Subject: [PATCH] If all operands of a Pure operation are dimensional
 identifiers, then its results are also dimensional identifiers.

---
 mlir/docs/Dialects/Affine.md                  |   4 +-
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      |  17 +-
 mlir/test/Dialect/Affine/invalid.mlir         |  12 --
 .../Dialect/Affine/load-store-invalid.mlir    |  92 ---------
 mlir/test/Dialect/Affine/ops.mlir             | 174 ++++++++++++++++++
 5 files changed, 187 insertions(+), 112 deletions(-)

diff --git a/mlir/docs/Dialects/Affine.md b/mlir/docs/Dialects/Affine.md
index 0b6d7747e8a6f9..7306e5fbb37b74 100644
--- a/mlir/docs/Dialects/Affine.md
+++ b/mlir/docs/Dialects/Affine.md
@@ -83,8 +83,8 @@ location of the SSA use. Dimensions may be bound not only to anything that a
 symbol is bound to, but also to induction variables of enclosing
 [`affine.for`](#affinefor-mliraffineforop) and
 [`affine.parallel`](#affineparallel-mliraffineparallelop) operations, and the result
-of an [`affine.apply` operation](#affineapply-mliraffineapplyop) (which recursively
-may use other dimensions and symbols).
+of a `Pure` operation whose operands are all valid dimensional identifiers
+(which recursively may use other dimensions and symbols).
 
 ### Affine Expressions
 
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 147f5dd7a24b62..053f8e0bb4f2c8 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -274,7 +274,8 @@ Region *mlir::affine::getAffineScope(Operation *op) {
 // conditions:
 // *) It is valid as a symbol.
 // *) It is an induction variable.
-// *) It is the result of affine apply operation with dimension id arguments.
+// *) It is the result of a `Pure` operation whose operands are valid
+//    dimensional identifiers.
 bool mlir::affine::isValidDim(Value value) {
   // The value must be an index type.
   if (!value.getType().isIndex())
@@ -304,8 +305,8 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
   if (isValidSymbol(value, region))
     return true;
 
-  auto *op = value.getDefiningOp();
-  if (!op) {
+  auto *defOp = value.getDefiningOp();
+  if (!defOp) {
     // This value has to be a block argument for an affine.for or an
     // affine.parallel.
     auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
@@ -313,11 +314,15 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
   }
 
   // Affine apply operation is ok if all of its operands are ok.
-  if (auto applyOp = dyn_cast<AffineApplyOp>(op))
-    return applyOp.isValidDim(region);
+  if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
+        return affine::isValidDim(operand, region);
+      })) {
+    return true;
+  }
+
   // The dim op is okay if its operand memref/tensor is defined at the top
   // level.
-  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
+  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
     return isTopLevelValue(dimOp.getShapedValue());
   return false;
 }
diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
index 44e484b9ba5982..b3b2ec5552482a 100644
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -225,18 +225,6 @@ func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
 
 // -----
 
-func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
-  affine.for %x = 0 to 7 {
-    %y = arith.addi %x, %x : index
-    // expected-error at +1 {{operand cannot be used as a dimension id}}
-    affine.parallel (%i, %j) = (0, 0) to (%y, 100) step (10, 10) {
-    }
-  }
-  return
-}
-
-// -----
-
 func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
   affine.for %x = 0 to 7 {
     %y = arith.addi %x, %x : index
diff --git a/mlir/test/Dialect/Affine/load-store-invalid.mlir b/mlir/test/Dialect/Affine/load-store-invalid.mlir
index 01d6b25dee695b..d8eac141cc52ae 100644
--- a/mlir/test/Dialect/Affine/load-store-invalid.mlir
+++ b/mlir/test/Dialect/Affine/load-store-invalid.mlir
@@ -33,31 +33,6 @@ func.func @store_too_few_subscripts_map(%arg0: memref<?x?xf32>, %arg1: index, %v
 
 // -----
 
-func.func @load_non_affine_index(%arg0 : index) {
-  %0 = memref.alloc() : memref<10xf32>
-  affine.for %i0 = 0 to 10 {
-    %1 = arith.muli %i0, %arg0 : index
-    // expected-error at +1 {{op index must be a valid dimension or symbol identifier}}
-    %v = affine.load %0[%1] : memref<10xf32>
-  }
-  return
-}
-
-// -----
-
-func.func @store_non_affine_index(%arg0 : index) {
-  %0 = memref.alloc() : memref<10xf32>
-  %1 = arith.constant 11.0 : f32
-  affine.for %i0 = 0 to 10 {
-    %2 = arith.muli %i0, %arg0 : index
-    // expected-error at +1 {{op index must be a valid dimension or symbol identifier}}
-    affine.store %1, %0[%2] : memref<10xf32>
-  }
-  return
-}
-
-// -----
-
 func.func @invalid_prefetch_rw(%i : index) {
   %0 = memref.alloc() : memref<10xf32>
   // expected-error at +1 {{rw specifier has to be 'read' or 'write'}}
@@ -73,70 +48,3 @@ func.func @invalid_prefetch_cache_type(%i : index) {
   affine.prefetch %0[%i], read, locality<0>, false  : memref<10xf32>
   return
 }
-
-// -----
-
-func.func @dma_start_non_affine_src_index(%arg0 : index) {
-  %0 = memref.alloc() : memref<100xf32>
-  %1 = memref.alloc() : memref<100xf32, 2>
-  %2 = memref.alloc() : memref<1xi32, 4>
-  %c0 = arith.constant 0 : index
-  %c64 = arith.constant 64 : index
-  affine.for %i0 = 0 to 10 {
-    %3 = arith.muli %i0, %arg0 : index
-    // expected-error at +1 {{op src index must be a valid dimension or symbol identifier}}
-    affine.dma_start %0[%3], %1[%i0], %2[%c0], %c64
-        : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
-  }
-  return
-}
-
-// -----
-
-func.func @dma_start_non_affine_dst_index(%arg0 : index) {
-  %0 = memref.alloc() : memref<100xf32>
-  %1 = memref.alloc() : memref<100xf32, 2>
-  %2 = memref.alloc() : memref<1xi32, 4>
-  %c0 = arith.constant 0 : index
-  %c64 = arith.constant 64 : index
-  affine.for %i0 = 0 to 10 {
-    %3 = arith.muli %i0, %arg0 : index
-    // expected-error at +1 {{op dst index must be a valid dimension or symbol identifier}}
-    affine.dma_start %0[%i0], %1[%3], %2[%c0], %c64
-        : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
-  }
-  return
-}
-
-// -----
-
-func.func @dma_start_non_affine_tag_index(%arg0 : index) {
-  %0 = memref.alloc() : memref<100xf32>
-  %1 = memref.alloc() : memref<100xf32, 2>
-  %2 = memref.alloc() : memref<1xi32, 4>
-  %c0 = arith.constant 0 : index
-  %c64 = arith.constant 64 : index
-  affine.for %i0 = 0 to 10 {
-    %3 = arith.muli %i0, %arg0 : index
-    // expected-error at +1 {{op tag index must be a valid dimension or symbol identifier}}
-    affine.dma_start %0[%i0], %1[%arg0], %2[%3], %c64
-        : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
-  }
-  return
-}
-
-// -----
-
-func.func @dma_wait_non_affine_tag_index(%arg0 : index) {
-  %0 = memref.alloc() : memref<100xf32>
-  %1 = memref.alloc() : memref<100xf32, 2>
-  %2 = memref.alloc() : memref<1xi32, 4>
-  %c0 = arith.constant 0 : index
-  %c64 = arith.constant 64 : index
-  affine.for %i0 = 0 to 10 {
-    %3 = arith.muli %i0, %arg0 : index
-    // expected-error at +1 {{op index must be a valid dimension or symbol identifier}}
-    affine.dma_wait %2[%3], %c64 : memref<1xi32, 4>
-  }
-  return
-}
diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
index e3721806989bb9..74ba098ce27487 100644
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -409,3 +409,177 @@ func.func @arith_add_vaild_symbol_lower_bound(%arg : index) {
 // CHECK:   affine.for %[[VAL_3:.*]] = #[[$ATTR_0]](%[[VAL_2]]){{\[}}%[[VAL_0]]] to 7 {
 // CHECK:   }
 // CHECK: }
+
+// -----
+
+// CHECK-LABEL: func @affine_parallel
+
+func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
+  affine.for %x = 0 to 7 {
+    %y = arith.addi %x, %x : index
+    affine.parallel (%i, %j) = (0, 0) to (%y, 100) step (10, 10) {
+    }
+  }
+  return
+}
+
+// CHECK-NEXT: affine.for 
+// CHECK-SAME: %[[VAL_0:.*]] = 0 to 7 {
+// CHECK: %[[VAL_1:.*]] = arith.addi %[[VAL_0]], %[[VAL_0]] : index
+// CHECK:   affine.parallel (%{{.*}}, %{{.*}}) = (0, 0) to (%[[VAL_1]], 100) step (10, 10) {
+// CHECK:   }
+// CHECK: }
+
+// -----
+
+func.func @load_non_affine_index(%arg0 : index) {
+  %0 = memref.alloc() : memref<10xf32>
+  affine.for %i0 = 0 to 10 {
+    %1 = arith.muli %i0, %arg0 : index
+    %v = affine.load %0[%1] : memref<10xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @load_non_affine_index
+// CHECK-SAME:  %[[VAL_0:.*]]: index) {
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<10xf32>
+// CHECK: affine.for %[[VAL_2:.*]] = 0 to 10 {
+// CHECK:   %[[VAL_3:.*]] = arith.muli %[[VAL_2]], %[[VAL_0]] : index
+// CHECK:   %{{.*}} = affine.load %[[VAL_1]]{{\[}}%[[VAL_3]]] : memref<10xf32>
+// CHECK: }
+
+// -----
+
+func.func @store_non_affine_index(%arg0 : index) {
+  %0 = memref.alloc() : memref<10xf32>
+  %1 = arith.constant 11.0 : f32
+  affine.for %i0 = 0 to 10 {
+    %2 = arith.muli %i0, %arg0 : index
+    affine.store %1, %0[%2] : memref<10xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @store_non_affine_index
+// CHECK-SAME: %[[VAL_0:.*]]: index) {
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<10xf32>
+// CHECK: %[[VAL_2:.*]] = arith.constant 1.100000e+01 : f32
+// CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
+// CHECK:   %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_0]] : index
+// CHECK:    affine.store %[[VAL_2]], %[[VAL_1]]{{\[}}%[[VAL_4]]] : memref<10xf32>
+// CHECK: }
+
+// -----
+
+func.func @dma_start_non_affine_src_index(%arg0 : index) {
+  %0 = memref.alloc() : memref<100xf32>
+  %1 = memref.alloc() : memref<100xf32, 2>
+  %2 = memref.alloc() : memref<1xi32, 4>
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  affine.for %i0 = 0 to 10 {
+    %3 = arith.muli %i0, %arg0 : index
+    affine.dma_start %0[%3], %1[%i0], %2[%c0], %c64
+        : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
+  }
+  return
+}
+
+// CHECK-LABEL: func @dma_start_non_affine_src_index
+// CHECK-SAME: %[[VAL_0:.*]]: index) {
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<100xf32>
+// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<100xf32, 2>
+// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<1xi32, 4>
+// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 64 : index
+// CHECK: affine.for %[[VAL_6:.*]] = 0 to 10 {
+// CHECK:   %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_0]] : index
+// CHECK:   affine.dma_start %[[VAL_1]]{{\[}}%[[VAL_7]]], %[[VAL_2]]{{\[}}%[[VAL_6]]], %[[VAL_3]]{{\[}}%[[VAL_4]]], %[[VAL_5]]
+// CHECK-SAME: : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
+// CHECK: }
+
+// -----
+
+func.func @dma_start_non_affine_dst_index(%arg0 : index) {
+  %0 = memref.alloc() : memref<100xf32>
+  %1 = memref.alloc() : memref<100xf32, 2>
+  %2 = memref.alloc() : memref<1xi32, 4>
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  affine.for %i0 = 0 to 10 {
+    %3 = arith.muli %i0, %arg0 : index
+    affine.dma_start %0[%i0], %1[%3], %2[%c0], %c64
+        : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
+  }
+  return
+}
+
+// CHECK-LABEL: func @dma_start_non_affine_dst_index
+// CHECK-SAME: %[[VAL_0:.*]]: index) {
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<100xf32>
+// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<100xf32, 2>
+// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<1xi32, 4>
+// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 64 : index
+// CHECK: affine.for %[[VAL_6:.*]] = 0 to 10 {
+// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_0]] : index
+// CHECK: affine.dma_start %[[VAL_1]]{{\[}}%[[VAL_6]]], %[[VAL_2]]{{\[}}%[[VAL_7]]], %[[VAL_3]]{{\[}}%[[VAL_4]]], %[[VAL_5]]
+// CHECK-SAME: : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
+// CHECK: }
+
+// -----
+
+func.func @dma_start_non_affine_tag_index(%arg0 : index) {
+  %0 = memref.alloc() : memref<100xf32>
+  %1 = memref.alloc() : memref<100xf32, 2>
+  %2 = memref.alloc() : memref<1xi32, 4>
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  affine.for %i0 = 0 to 10 {
+    %3 = arith.muli %i0, %arg0 : index
+    affine.dma_start %0[%i0], %1[%arg0], %2[%3], %c64
+        : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
+  }
+  return
+}
+
+// CHECK-LABEL: func @dma_start_non_affine_tag_index
+// CHECK-SAME: %[[VAL_0:.*]]: index) {
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<100xf32>
+// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<100xf32, 2>
+// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<1xi32, 4>
+// CHECK: %{{.*}} = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 64 : index
+// CHECK: affine.for %[[VAL_5:.*]] = 0 to 10 {
+// CHECK: %[[VAL_6:.*]] = arith.muli %[[VAL_5]], %[[VAL_0]] : index
+// CHECK: affine.dma_start %[[VAL_1]]{{\[}}%[[VAL_5]]], %[[VAL_2]]{{\[}}%[[VAL_0]]], %[[VAL_3]]{{\[}}%[[VAL_6]]], %[[VAL_4]]
+// CHECK-SAME: : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
+// CHECK: }
+
+// -----
+
+func.func @dma_wait_non_affine_tag_index(%arg0 : index) {
+  %0 = memref.alloc() : memref<100xf32>
+  %1 = memref.alloc() : memref<100xf32, 2>
+  %2 = memref.alloc() : memref<1xi32, 4>
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  affine.for %i0 = 0 to 10 {
+    %3 = arith.muli %i0, %arg0 : index
+    affine.dma_wait %2[%3], %c64 : memref<1xi32, 4>
+  }
+  return
+}
+
+// CHECK-LABEL: func @dma_wait_non_affine_tag_index
+// CHECK-SAME: %[[VAL_0:.*]]: index) {
+// CHECK: %{{.*}} = memref.alloc() : memref<100xf32>
+// CHECK: %{{.*}} = memref.alloc() : memref<100xf32, 2>
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<1xi32, 4>
+// CHECK: %{{.*}} = arith.constant 0 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 64 : index
+// CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
+// CHECK:   %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_0]] : index
+// CHECK:   affine.dma_wait %[[VAL_1]]{{\[}}%[[VAL_4]]], %[[VAL_2]] : memref<1xi32, 4>
+// CHECK: }



More information about the Mlir-commits mailing list