[Mlir-commits] [mlir] [mlir][affine] If the operands of a Pure operation are valid dimensional identifiers, then its results are valid dimensional identifiers. (PR #123542)
llvmlistbot at llvm.org
Sun Jan 19 18:48:47 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-affine
Author: lonely eagle (linuxlonelyeagle)
Changes:
If all operands of a `Pure` operation are valid dimensional identifiers, then its results are also valid dimensional identifiers.
See the discussion at https://github.com/llvm/llvm-project/pull/118478#issuecomment-2592921000.
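For illustration, here is a minimal MLIR sketch, taken directly from the `@load_non_affine_index` test that this PR moves from `load-store-invalid.mlir` into `ops.mlir`: the `arith.muli` result is built only from a loop induction variable and a top-level symbol, both valid dimensional identifiers, so with this change it can itself be used as a dimension and the IR now verifies.

```mlir
func.func @load_non_affine_index(%arg0 : index) {
  %0 = memref.alloc() : memref<10xf32>
  affine.for %i0 = 0 to 10 {
    // %i0 is an induction variable and %arg0 is a top-level symbol, so both
    // are valid dimensional identifiers; arith.muli is Pure, so its result
    // is now accepted as a dimension by the affine.load below.
    %1 = arith.muli %i0, %arg0 : index
    %v = affine.load %0[%1] : memref<10xf32>
  }
  return
}
```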
---
Full diff: https://github.com/llvm/llvm-project/pull/123542.diff
5 Files Affected:
- (modified) mlir/docs/Dialects/Affine.md (+2-3)
- (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+11-6)
- (modified) mlir/test/Dialect/Affine/invalid.mlir (-12)
- (modified) mlir/test/Dialect/Affine/load-store-invalid.mlir (-92)
- (modified) mlir/test/Dialect/Affine/ops.mlir (+174)
``````````diff
diff --git a/mlir/docs/Dialects/Affine.md b/mlir/docs/Dialects/Affine.md
index 0b6d7747e8a6f9..94f23af699ca46 100644
--- a/mlir/docs/Dialects/Affine.md
+++ b/mlir/docs/Dialects/Affine.md
@@ -83,9 +83,8 @@ location of the SSA use. Dimensions may be bound not only to anything that a
symbol is bound to, but also to induction variables of enclosing
[`affine.for`](#affinefor-mliraffineforop) and
[`affine.parallel`](#affineparallel-mliraffineparallelop) operations, and the result
-of an [`affine.apply` operation](#affineapply-mliraffineapplyop) (which recursively
-may use other dimensions and symbols).
-
+of a `Pure` operation whose operands are valid dimensional identifiers (which
+recursively may use other dimensions and symbols).
### Affine Expressions
Syntax:
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 147f5dd7a24b62..053f8e0bb4f2c8 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -274,7 +274,8 @@ Region *mlir::affine::getAffineScope(Operation *op) {
// conditions:
// *) It is valid as a symbol.
// *) It is an induction variable.
-// *) It is the result of affine apply operation with dimension id arguments.
+// *) It is the result of a `Pure` operation whose operands are valid
+// dimensional identifiers.
bool mlir::affine::isValidDim(Value value) {
// The value must be an index type.
if (!value.getType().isIndex())
@@ -304,8 +305,8 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
if (isValidSymbol(value, region))
return true;
- auto *op = value.getDefiningOp();
- if (!op) {
+ auto *defOp = value.getDefiningOp();
+ if (!defOp) {
// This value has to be a block argument for an affine.for or an
// affine.parallel.
auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
@@ -313,11 +314,15 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
}
// Affine apply operation is ok if all of its operands are ok.
- if (auto applyOp = dyn_cast<AffineApplyOp>(op))
- return applyOp.isValidDim(region);
+ if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
+ return affine::isValidDim(operand, region);
+ })) {
+ return true;
+ }
+
// The dim op is okay if its operand memref/tensor is defined at the top
// level.
- if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
+ if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
return isTopLevelValue(dimOp.getShapedValue());
return false;
}
diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
index 44e484b9ba5982..b3b2ec5552482a 100644
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -225,18 +225,6 @@ func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
// -----
-func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
- affine.for %x = 0 to 7 {
- %y = arith.addi %x, %x : index
- // expected-error at +1 {{operand cannot be used as a dimension id}}
- affine.parallel (%i, %j) = (0, 0) to (%y, 100) step (10, 10) {
- }
- }
- return
-}
-
-// -----
-
func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
affine.for %x = 0 to 7 {
%y = arith.addi %x, %x : index
diff --git a/mlir/test/Dialect/Affine/load-store-invalid.mlir b/mlir/test/Dialect/Affine/load-store-invalid.mlir
index 01d6b25dee695b..d8eac141cc52ae 100644
--- a/mlir/test/Dialect/Affine/load-store-invalid.mlir
+++ b/mlir/test/Dialect/Affine/load-store-invalid.mlir
@@ -33,31 +33,6 @@ func.func @store_too_few_subscripts_map(%arg0: memref<?x?xf32>, %arg1: index, %v
// -----
-func.func @load_non_affine_index(%arg0 : index) {
- %0 = memref.alloc() : memref<10xf32>
- affine.for %i0 = 0 to 10 {
- %1 = arith.muli %i0, %arg0 : index
- // expected-error at +1 {{op index must be a valid dimension or symbol identifier}}
- %v = affine.load %0[%1] : memref<10xf32>
- }
- return
-}
-
-// -----
-
-func.func @store_non_affine_index(%arg0 : index) {
- %0 = memref.alloc() : memref<10xf32>
- %1 = arith.constant 11.0 : f32
- affine.for %i0 = 0 to 10 {
- %2 = arith.muli %i0, %arg0 : index
- // expected-error at +1 {{op index must be a valid dimension or symbol identifier}}
- affine.store %1, %0[%2] : memref<10xf32>
- }
- return
-}
-
-// -----
-
func.func @invalid_prefetch_rw(%i : index) {
%0 = memref.alloc() : memref<10xf32>
// expected-error at +1 {{rw specifier has to be 'read' or 'write'}}
@@ -73,70 +48,3 @@ func.func @invalid_prefetch_cache_type(%i : index) {
affine.prefetch %0[%i], read, locality<0>, false : memref<10xf32>
return
}
-
-// -----
-
-func.func @dma_start_non_affine_src_index(%arg0 : index) {
- %0 = memref.alloc() : memref<100xf32>
- %1 = memref.alloc() : memref<100xf32, 2>
- %2 = memref.alloc() : memref<1xi32, 4>
- %c0 = arith.constant 0 : index
- %c64 = arith.constant 64 : index
- affine.for %i0 = 0 to 10 {
- %3 = arith.muli %i0, %arg0 : index
- // expected-error at +1 {{op src index must be a valid dimension or symbol identifier}}
- affine.dma_start %0[%3], %1[%i0], %2[%c0], %c64
- : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
- }
- return
-}
-
-// -----
-
-func.func @dma_start_non_affine_dst_index(%arg0 : index) {
- %0 = memref.alloc() : memref<100xf32>
- %1 = memref.alloc() : memref<100xf32, 2>
- %2 = memref.alloc() : memref<1xi32, 4>
- %c0 = arith.constant 0 : index
- %c64 = arith.constant 64 : index
- affine.for %i0 = 0 to 10 {
- %3 = arith.muli %i0, %arg0 : index
- // expected-error at +1 {{op dst index must be a valid dimension or symbol identifier}}
- affine.dma_start %0[%i0], %1[%3], %2[%c0], %c64
- : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
- }
- return
-}
-
-// -----
-
-func.func @dma_start_non_affine_tag_index(%arg0 : index) {
- %0 = memref.alloc() : memref<100xf32>
- %1 = memref.alloc() : memref<100xf32, 2>
- %2 = memref.alloc() : memref<1xi32, 4>
- %c0 = arith.constant 0 : index
- %c64 = arith.constant 64 : index
- affine.for %i0 = 0 to 10 {
- %3 = arith.muli %i0, %arg0 : index
- // expected-error at +1 {{op tag index must be a valid dimension or symbol identifier}}
- affine.dma_start %0[%i0], %1[%arg0], %2[%3], %c64
- : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
- }
- return
-}
-
-// -----
-
-func.func @dma_wait_non_affine_tag_index(%arg0 : index) {
- %0 = memref.alloc() : memref<100xf32>
- %1 = memref.alloc() : memref<100xf32, 2>
- %2 = memref.alloc() : memref<1xi32, 4>
- %c0 = arith.constant 0 : index
- %c64 = arith.constant 64 : index
- affine.for %i0 = 0 to 10 {
- %3 = arith.muli %i0, %arg0 : index
- // expected-error at +1 {{op index must be a valid dimension or symbol identifier}}
- affine.dma_wait %2[%3], %c64 : memref<1xi32, 4>
- }
- return
-}
diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
index e3721806989bb9..74ba098ce27487 100644
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -409,3 +409,177 @@ func.func @arith_add_vaild_symbol_lower_bound(%arg : index) {
// CHECK: affine.for %[[VAL_3:.*]] = #[[$ATTR_0]](%[[VAL_2]]){{\[}}%[[VAL_0]]] to 7 {
// CHECK: }
// CHECK: }
+
+// -----
+
+// CHECK-LABEL: func @affine_parallel
+
+func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
+ affine.for %x = 0 to 7 {
+ %y = arith.addi %x, %x : index
+ affine.parallel (%i, %j) = (0, 0) to (%y, 100) step (10, 10) {
+ }
+ }
+ return
+}
+
+// CHECK-NEXT: affine.for
+// CHECK-SAME: %[[VAL_0:.*]] = 0 to 7 {
+// CHECK: %[[VAL_1:.*]] = arith.addi %[[VAL_0]], %[[VAL_0]] : index
+// CHECK: affine.parallel (%{{.*}}, %{{.*}}) = (0, 0) to (%[[VAL_1]], 100) step (10, 10) {
+// CHECK: }
+// CHECK: }
+
+// -----
+
+func.func @load_non_affine_index(%arg0 : index) {
+ %0 = memref.alloc() : memref<10xf32>
+ affine.for %i0 = 0 to 10 {
+ %1 = arith.muli %i0, %arg0 : index
+ %v = affine.load %0[%1] : memref<10xf32>
+ }
+ return
+}
+
+// CHECK-LABEL: func @load_non_affine_index
+// CHECK-SAME: %[[VAL_0:.*]]: index) {
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<10xf32>
+// CHECK: affine.for %[[VAL_2:.*]] = 0 to 10 {
+// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_2]], %[[VAL_0]] : index
+// CHECK: %{{.*}} = affine.load %[[VAL_1]]{{\[}}%[[VAL_3]]] : memref<10xf32>
+// CHECK: }
+
+// -----
+
+func.func @store_non_affine_index(%arg0 : index) {
+ %0 = memref.alloc() : memref<10xf32>
+ %1 = arith.constant 11.0 : f32
+ affine.for %i0 = 0 to 10 {
+ %2 = arith.muli %i0, %arg0 : index
+ affine.store %1, %0[%2] : memref<10xf32>
+ }
+ return
+}
+
+// CHECK-LABEL: func @store_non_affine_index
+// CHECK-SAME: %[[VAL_0:.*]]: index) {
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<10xf32>
+// CHECK: %[[VAL_2:.*]] = arith.constant 1.100000e+01 : f32
+// CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
+// CHECK: %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_0]] : index
+// CHECK: affine.store %[[VAL_2]], %[[VAL_1]]{{\[}}%[[VAL_4]]] : memref<10xf32>
+// CHECK: }
+
+// -----
+
+func.func @dma_start_non_affine_src_index(%arg0 : index) {
+ %0 = memref.alloc() : memref<100xf32>
+ %1 = memref.alloc() : memref<100xf32, 2>
+ %2 = memref.alloc() : memref<1xi32, 4>
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ affine.for %i0 = 0 to 10 {
+ %3 = arith.muli %i0, %arg0 : index
+ affine.dma_start %0[%3], %1[%i0], %2[%c0], %c64
+ : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
+ }
+ return
+}
+
+// CHECK-LABEL: func @dma_start_non_affine_src_index
+// CHECK-SAME: %[[VAL_0:.*]]: index) {
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<100xf32>
+// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<100xf32, 2>
+// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<1xi32, 4>
+// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 64 : index
+// CHECK: affine.for %[[VAL_6:.*]] = 0 to 10 {
+// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_0]] : index
+// CHECK: affine.dma_start %[[VAL_1]]{{\[}}%[[VAL_7]]], %[[VAL_2]]{{\[}}%[[VAL_6]]], %[[VAL_3]]{{\[}}%[[VAL_4]]], %[[VAL_5]]
+// CHECK-SAME: : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
+// CHECK: }
+
+// -----
+
+func.func @dma_start_non_affine_dst_index(%arg0 : index) {
+ %0 = memref.alloc() : memref<100xf32>
+ %1 = memref.alloc() : memref<100xf32, 2>
+ %2 = memref.alloc() : memref<1xi32, 4>
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ affine.for %i0 = 0 to 10 {
+ %3 = arith.muli %i0, %arg0 : index
+ affine.dma_start %0[%i0], %1[%3], %2[%c0], %c64
+ : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
+ }
+ return
+}
+
+// CHECK-LABEL: func @dma_start_non_affine_dst_index
+// CHECK-SAME: %[[VAL_0:.*]]: index) {
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<100xf32>
+// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<100xf32, 2>
+// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<1xi32, 4>
+// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 64 : index
+// CHECK: affine.for %[[VAL_6:.*]] = 0 to 10 {
+// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_0]] : index
+// CHECK: affine.dma_start %[[VAL_1]]{{\[}}%[[VAL_6]]], %[[VAL_2]]{{\[}}%[[VAL_7]]], %[[VAL_3]]{{\[}}%[[VAL_4]]], %[[VAL_5]]
+// CHECK-SAME: : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
+// CHECK: }
+
+// -----
+
+func.func @dma_start_non_affine_tag_index(%arg0 : index) {
+ %0 = memref.alloc() : memref<100xf32>
+ %1 = memref.alloc() : memref<100xf32, 2>
+ %2 = memref.alloc() : memref<1xi32, 4>
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ affine.for %i0 = 0 to 10 {
+ %3 = arith.muli %i0, %arg0 : index
+ affine.dma_start %0[%i0], %1[%arg0], %2[%3], %c64
+ : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
+ }
+ return
+}
+
+// CHECK-LABEL: func @dma_start_non_affine_tag_index
+// CHECK-SAME: %[[VAL_0:.*]]: index) {
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<100xf32>
+// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<100xf32, 2>
+// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<1xi32, 4>
+// CHECK: %{{.*}} = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 64 : index
+// CHECK: affine.for %[[VAL_5:.*]] = 0 to 10 {
+// CHECK: %[[VAL_6:.*]] = arith.muli %[[VAL_5]], %[[VAL_0]] : index
+// CHECK: affine.dma_start %[[VAL_1]]{{\[}}%[[VAL_5]]], %[[VAL_2]]{{\[}}%[[VAL_0]]], %[[VAL_3]]{{\[}}%[[VAL_6]]], %[[VAL_4]]
+// CHECK-SAME: : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
+// CHECK: }
+
+// -----
+
+func.func @dma_wait_non_affine_tag_index(%arg0 : index) {
+ %0 = memref.alloc() : memref<100xf32>
+ %1 = memref.alloc() : memref<100xf32, 2>
+ %2 = memref.alloc() : memref<1xi32, 4>
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ affine.for %i0 = 0 to 10 {
+ %3 = arith.muli %i0, %arg0 : index
+ affine.dma_wait %2[%3], %c64 : memref<1xi32, 4>
+ }
+ return
+}
+
+// CHECK-LABEL: func @dma_wait_non_affine_tag_index
+// CHECK-SAME: %[[VAL_0:.*]]: index) {
+// CHECK: %{{.*}} = memref.alloc() : memref<100xf32>
+// CHECK: %{{.*}} = memref.alloc() : memref<100xf32, 2>
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<1xi32, 4>
+// CHECK: %{{.*}} = arith.constant 0 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 64 : index
+// CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
+// CHECK: %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_0]] : index
+// CHECK: affine.dma_wait %[[VAL_1]]{{\[}}%[[VAL_4]]], %[[VAL_2]] : memref<1xi32, 4>
+// CHECK: }
``````````
https://github.com/llvm/llvm-project/pull/123542