[Mlir-commits] [mlir] [mlir][affine] Remove `isValidAffineIndexOperand` (PR #73027)

Rik Huijzer llvmlistbot at llvm.org
Tue Nov 21 11:13:28 PST 2023


https://github.com/rikhuijzer created https://github.com/llvm/llvm-project/pull/73027

The function
```cpp
// Returns true if 'value' is a valid index to an affine operation (e.g.
// affine.load, affine.store, affine.dma_start, affine.dma_wait) where
// `region` provides the polyhedral symbol scope. Returns false otherwise.
// NOTE: isValidDim already returns true for every index-typed valid symbol,
// and isValidSymbol rejects non-index values, so the `|| isValidSymbol(...)`
// operand never changes the result.
static bool isValidAffineIndexOperand(Value value, Region *region) {
  return isValidDim(value, region) || isValidSymbol(value, region);
}
```
is redundant because `isValidDim` is defined as
```cpp
// Returns true if `value` can be used as an affine dimension identifier in the
// polyhedral scope provided by `region`. The value must have index type and
// be one of:
//   - a valid symbol (isValidSymbol),
//   - a block argument of a block whose parent op is affine.for or
//     affine.parallel,
//   - the result of an affine.apply for which AffineApplyOp::isValidDim holds,
//   - the result of a ShapedDimOpInterface op whose shaped operand is a
//     top-level value.
// Returns false otherwise.
bool mlir::affine::isValidDim(Value value, Region *region) {
  // The value must be an index type.
  if (!value.getType().isIndex())
    return false;

  // All valid symbols are okay. This makes every valid symbol a valid dim, so
  // `isValidDim(v, r) || isValidSymbol(v, r)` is equivalent to
  // `isValidDim(v, r)`.
  if (isValidSymbol(value, region))
    return true;

  auto *op = value.getDefiningOp();
  if (!op) {
    // This value has to be a block argument for an affine.for or an
    // affine.parallel. Only the owning block's parent op is checked, so any
    // argument of such a block is accepted.
    auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
    return isa<AffineForOp, AffineParallelOp>(parentOp);
  }

  // Affine apply operation is ok if all of its operands are ok.
  if (auto applyOp = dyn_cast<AffineApplyOp>(op))
    return applyOp.isValidDim(region);
  // The dim op is okay if its operand memref/tensor is defined at the top
  // level.
  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
    return isTopLevelValue(dimOp.getShapedValue());
  return false;
}
```
and `isValidSymbol` is defined as
```cpp
bool mlir::affine::isValidSymbol(Value value, Region *region) {
  // The value must be an index type.
  if (!value.getType().isIndex())
    return false;

  [...]
}
```

To see that the function is redundant, consider the following cases:

- If `isValidDim(value, region)` is true, then `isValidDim(value, region) || isValidSymbol(value, region)` is trivially true as well.
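- If `isValidDim(value, region)` is false, there are two possibilities. Either `value.getType().isIndex()` is false, in which case `isValidSymbol` is also false, because it starts with the same index-type check. Or `value.getType().isIndex()` is true, in which case `isValidSymbol` must also be false, because otherwise `isValidDim` would have returned true at its `isValidSymbol(value, region)` check. Either way, the disjunction is false.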

Or, put differently, consider the following cases:

- If `value.getType().isIndex()` is false, then `isValidDim(value, region)` and `isValidSymbol(value, region)` are both false, so their disjunction is false as well.
- If `value.getType().isIndex()` is true, there are two possibilities. Either `isValidDim` is true, so the disjunction is true. Or `isValidDim` is false, which implies that `isValidSymbol` is false too (otherwise `isValidDim` would have returned true), so the disjunction is false.

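In all cases, `isValidDim(value, region) || isValidSymbol(value, region)` evaluates to the same result as `isValidDim(value, region)`. Therefore every call to `isValidAffineIndexOperand` can be replaced by a call to `isValidDim`.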

>From 007482c978606d47b677a9668c5a9c7b0527ae35 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Tue, 21 Nov 2023 19:50:40 +0100
Subject: [PATCH] [mlir][affine] Remove `isValidAffineIndexOperand`

---
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 34 +++++++++---------------
 1 file changed, 13 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index d22a7539fb75018..61c66361ce1fb32 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -434,13 +434,6 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) {
   return false;
 }
 
-// Returns true if 'value' is a valid index to an affine operation (e.g.
-// affine.load, affine.store, affine.dma_start, affine.dma_wait) where
-// `region` provides the polyhedral symbol scope. Returns false otherwise.
-static bool isValidAffineIndexOperand(Value value, Region *region) {
-  return isValidDim(value, region) || isValidSymbol(value, region);
-}
-
 /// Prints dimension and symbol list.
 static void printDimAndSymbolList(Operation::operand_iterator begin,
                                   Operation::operand_iterator end,
@@ -1650,19 +1643,19 @@ LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
   for (auto idx : getSrcIndices()) {
     if (!idx.getType().isIndex())
       return emitOpError("src index to dma_start must have 'index' type");
-    if (!isValidAffineIndexOperand(idx, scope))
+    if (!isValidDim(idx, scope))
       return emitOpError("src index must be a dimension or symbol identifier");
   }
   for (auto idx : getDstIndices()) {
     if (!idx.getType().isIndex())
       return emitOpError("dst index to dma_start must have 'index' type");
-    if (!isValidAffineIndexOperand(idx, scope))
+    if (!isValidDim(idx, scope))
       return emitOpError("dst index must be a dimension or symbol identifier");
   }
   for (auto idx : getTagIndices()) {
     if (!idx.getType().isIndex())
       return emitOpError("tag index to dma_start must have 'index' type");
-    if (!isValidAffineIndexOperand(idx, scope))
+    if (!isValidDim(idx, scope))
       return emitOpError("tag index must be a dimension or symbol identifier");
   }
   return success();
@@ -1751,7 +1744,7 @@ LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
   for (auto idx : getTagIndices()) {
     if (!idx.getType().isIndex())
       return emitOpError("index to dma_wait must have 'index' type");
-    if (!isValidAffineIndexOperand(idx, scope))
+    if (!isValidDim(idx, scope))
       return emitOpError("index must be a dimension or symbol identifier");
   }
   return success();
@@ -2913,8 +2906,7 @@ static void composeSetAndOperands(IntegerSet &set,
 }
 
 /// Canonicalize an affine if op's conditional (integer set + operands).
-LogicalResult AffineIfOp::fold(FoldAdaptor,
-                               SmallVectorImpl<OpFoldResult> &) {
+LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
   auto set = getIntegerSet();
   SmallVector<Value, 4> operands(getOperands());
   composeSetAndOperands(set, operands);
@@ -3005,17 +2997,17 @@ static LogicalResult
 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
                        Operation::operand_range mapOperands,
                        MemRefType memrefType, unsigned numIndexOperands) {
-    AffineMap map = mapAttr.getValue();
-    if (map.getNumResults() != memrefType.getRank())
-      return op->emitOpError("affine map num results must equal memref rank");
-    if (map.getNumInputs() != numIndexOperands)
-      return op->emitOpError("expects as many subscripts as affine map inputs");
+  AffineMap map = mapAttr.getValue();
+  if (map.getNumResults() != memrefType.getRank())
+    return op->emitOpError("affine map num results must equal memref rank");
+  if (map.getNumInputs() != numIndexOperands)
+    return op->emitOpError("expects as many subscripts as affine map inputs");
 
   Region *scope = getAffineScope(op);
-  for (auto idx : mapOperands) {
+  for (Value idx : mapOperands) {
     if (!idx.getType().isIndex())
       return op->emitOpError("index to load must have 'index' type");
-    if (!isValidAffineIndexOperand(idx, scope))
+    if (!isValidDim(idx, scope))
       return op->emitOpError("index must be a dimension or symbol identifier");
   }
 
@@ -3604,7 +3596,7 @@ LogicalResult AffinePrefetchOp::verify() {
 
   Region *scope = getAffineScope(*this);
   for (auto idx : getMapOperands()) {
-    if (!isValidAffineIndexOperand(idx, scope))
+    if (!isValidDim(idx, scope))
       return emitOpError("index must be a dimension or symbol identifier");
   }
   return success();



More information about the Mlir-commits mailing list