[Mlir-commits] [mlir] [MLIR][Affine] Update ::fold() to have constant basis attr for affine.delinearize_index/linearize_index wherever applicable (PR #117572)

Abhishek Varma llvmlistbot at llvm.org
Tue Nov 26 22:12:38 PST 2024


https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/117572

>From 7dc8a7d83f875d8357cc40c80d59bdd54842839b Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Mon, 25 Nov 2024 15:47:26 +0000
Subject: [PATCH 1/3] [MLIR][Affine] Add canonicalization pattern to have CST
 basis attr

-- This commit adds canonicalization patterns that replace constant (CST)
   SSA values in the basis of affine.delinearize_index/linearize_index ops
   with constant attributes wherever applicable.
-- Essentially, the patterns check whether the mixed-basis OpFoldResult
   set contains any constant SSA value and, if so, convert it to a
   constant integer attribute instead.

Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp   | 66 +++++++++++++++++++++-
 mlir/test/Dialect/Affine/canonicalize.mlir | 30 ++++++++++
 2 files changed, 94 insertions(+), 2 deletions(-)

NOTE(review): PATCH 2/3 below deletes fetchNewConstantBasis and both
ConstantAttributeBasis{Delinearize,Linearize}IndexOpPattern structs added
here, moving the constant-basis rewrite into the ops' ::fold() hooks. The
hunks below must stay byte-identical for the later patches to apply.
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 67d7da622a3550..3e82ec00763142 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4729,12 +4729,55 @@ struct CancelDelinearizeOfLinearizeDisjointExactTail
     return success();
   }
 };
+
+/// Give mixed basis of affine.delinearize_index/linearize_index replace
+/// constant SSA values with constant attribute as OpFoldResult. In case no
+/// change is made to the existing mixed basis set, return failure; success
+/// otherwise.
+static LogicalResult
+fetchNewConstantBasis(PatternRewriter &rewriter,
+                      SmallVector<OpFoldResult> mixedBasis,
+                      SmallVector<OpFoldResult> &newBasis) {
+  // Replace all constant SSA values with the constant attribute.
+  bool hasConstantSSAVal = false;
+  for (OpFoldResult basis : mixedBasis) {
+    std::optional<int64_t> basisVal = getConstantIntValue(basis);
+    if (basisVal && !isa<Attribute>(basis)) {
+      newBasis.push_back(rewriter.getIndexAttr(*basisVal));
+      hasConstantSSAVal = true;
+    } else {
+      newBasis.push_back(basis);
+    }
+  }
+  if (hasConstantSSAVal)
+    return success();
+  return failure();
+}
+
+/// Folds away constant SSA Value with constant Attribute in basis.
+struct ConstantAttributeBasisDelinearizeIndexOpPattern
+    : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    // Replace all constant SSA values with the constant attribute.
+    SmallVector<OpFoldResult> newBasis;
+    if (failed(fetchNewConstantBasis(rewriter, op.getMixedBasis(), newBasis)))
+      return rewriter.notifyMatchFailure(op, "no constant SSA value in basis");
+
+    rewriter.replaceOpWithNewOp<affine::AffineDelinearizeIndexOp>(
+        op, op.getLinearIndex(), newBasis, op.hasOuterBound());
+    return success();
+  }
+};
 } // namespace
 
 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
   patterns.insert<CancelDelinearizeOfLinearizeDisjointExactTail,
-                  DropUnitExtentBasis>(context);
+                  DropUnitExtentBasis,
+                  ConstantAttributeBasisDelinearizeIndexOpPattern>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -4959,12 +5002,31 @@ struct DropLinearizeLeadingZero final
     return success();
   }
 };
+
+/// Folds away constant SSA Value with constant Attribute in basis.
+struct ConstantAttributeBasisLinearizeIndexOpPattern
+    : public OpRewritePattern<affine::AffineLinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    // Replace all constant SSA values with the constant attribute.
+    SmallVector<OpFoldResult> newBasis;
+    if (failed(fetchNewConstantBasis(rewriter, op.getMixedBasis(), newBasis)))
+      return rewriter.notifyMatchFailure(op, "no constant SSA value in basis");
+
+    rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
+        op, op.getMultiIndex(), newBasis, op.getDisjoint());
+    return success();
+  }
+};
 } // namespace
 
 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
   patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeLeadingZero,
-               DropLinearizeUnitComponentsIfDisjointOrZero>(context);
+               DropLinearizeUnitComponentsIfDisjointOrZero,
+               ConstantAttributeBasisLinearizeIndexOpPattern>(context);
 }
 
 //===----------------------------------------------------------------------===//
NOTE(review): these CHECK lines also expect the unit (%c1) basis element to
be dropped, so they do not isolate the constant-basis rewrite. PATCH 3/3
replaces %c1 with %c2 and updates the CHECKs accordingly.
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index 5384977151b47f..16cbce35aeec7e 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1946,3 +1946,33 @@ func.func @affine_leading_zero_no_outer_bound(%arg0: index, %arg1: index) -> ind
   return %ret : index
 }
 
+// -----
+
+// CHECK-LABEL: @cst_value_to_cst_attr_basis_delinearize_index
+// CHECK-SAME:    (%[[ARG0:.*]]: index)
+// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         %[[RET:.*]]:2 = affine.delinearize_index %[[ARG0]] into (3, 4) : index, index
+// CHECK:         return %[[RET]]#0, %[[RET]]#1, %[[C0]] : index, index, index
+func.func @cst_value_to_cst_attr_basis_delinearize_index(%arg0 : index) ->
+    (index, index, index) {
+  %c4 = arith.constant 4 : index
+  %c3 = arith.constant 3 : index
+  %c1 = arith.constant 1 : index
+  %0:3 = affine.delinearize_index %arg0 into (%c3, %c4, %c1)
+      : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @cst_value_to_cst_attr_basis_linearize_index
+// CHECK-SAME:    (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+// CHECK:         %[[RET:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]]] by (3, 4) : index
+// CHECK:         return %[[RET]] : index
+func.func @cst_value_to_cst_attr_basis_linearize_index(%arg0 : index, %arg1 : index, %arg2 : index) ->
+    (index) {
+  %c4 = arith.constant 4 : index
+  %c1 = arith.constant 1 : index
+  %1 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by  (%c1, 3, %c4) : index
+  return %1 : index
+}

>From a0b8fb021a7600dc86c3ac1688a65174db72d218 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Tue, 26 Nov 2024 06:38:06 +0000
Subject: [PATCH 2/3] Review comment: fold constant basis in ::fold() instead of patterns

---
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 97 ++++++++----------------
 1 file changed, 33 insertions(+), 64 deletions(-)

NOTE(review): in this revision fold() calls setStaticBasis() unconditionally,
even when no operand folded, and then continues into the other folds after
having mutated the op in place. It may therefore report no change for an op
it modified. PATCH 3/3 makes the helper return std::nullopt when nothing
folds, and makes fold() return early once it has updated the basis.
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 3e82ec00763142..e48db6d5f521eb 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4569,9 +4569,38 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
   return success();
 }
 
+/// Give mixed basis of affine.delinearize_index/linearize_index replace
+/// constant SSA values with the constant integer value and returns the
+/// new static basis.
+static SmallVector<int64_t>
+foldCstValueToCstAttrBasis(ArrayRef<OpFoldResult> mixedBasis,
+                           MutableOperandRange mutableDynamicBasis,
+                           ArrayRef<Attribute> dynamicBasis) {
+  SmallVector<int64_t> staticBasis;
+  for (OpFoldResult basis : mixedBasis) {
+    std::optional<int64_t> basisVal = getConstantIntValue(basis);
+    if (!basisVal)
+      staticBasis.push_back(ShapedType::kDynamic);
+    else
+      staticBasis.push_back(*basisVal);
+  }
+
+  int64_t dynamicBasisIndex = 0;
+  for (OpFoldResult basis : dynamicBasis) {
+    if (basis) {
+      mutableDynamicBasis.erase(dynamicBasisIndex);
+    } else {
+      ++dynamicBasisIndex;
+    }
+  }
+  return staticBasis;
+}
+
 LogicalResult
 AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &result) {
+  setStaticBasis(foldCstValueToCstAttrBasis(
+      getMixedBasis(), getDynamicBasisMutable(), adaptor.getDynamicBasis()));
   // If we won't be doing any division or modulo (no basis or the one basis
   // element is purely advisory), simply return the input value.
   if (getNumResults() == 1) {
@@ -4729,55 +4758,12 @@ struct CancelDelinearizeOfLinearizeDisjointExactTail
     return success();
   }
 };
-
-/// Give mixed basis of affine.delinearize_index/linearize_index replace
-/// constant SSA values with constant attribute as OpFoldResult. In case no
-/// change is made to the existing mixed basis set, return failure; success
-/// otherwise.
-static LogicalResult
-fetchNewConstantBasis(PatternRewriter &rewriter,
-                      SmallVector<OpFoldResult> mixedBasis,
-                      SmallVector<OpFoldResult> &newBasis) {
-  // Replace all constant SSA values with the constant attribute.
-  bool hasConstantSSAVal = false;
-  for (OpFoldResult basis : mixedBasis) {
-    std::optional<int64_t> basisVal = getConstantIntValue(basis);
-    if (basisVal && !isa<Attribute>(basis)) {
-      newBasis.push_back(rewriter.getIndexAttr(*basisVal));
-      hasConstantSSAVal = true;
-    } else {
-      newBasis.push_back(basis);
-    }
-  }
-  if (hasConstantSSAVal)
-    return success();
-  return failure();
-}
-
-/// Folds away constant SSA Value with constant Attribute in basis.
-struct ConstantAttributeBasisDelinearizeIndexOpPattern
-    : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp op,
-                                PatternRewriter &rewriter) const override {
-    // Replace all constant SSA values with the constant attribute.
-    SmallVector<OpFoldResult> newBasis;
-    if (failed(fetchNewConstantBasis(rewriter, op.getMixedBasis(), newBasis)))
-      return rewriter.notifyMatchFailure(op, "no constant SSA value in basis");
-
-    rewriter.replaceOpWithNewOp<affine::AffineDelinearizeIndexOp>(
-        op, op.getLinearIndex(), newBasis, op.hasOuterBound());
-    return success();
-  }
-};
 } // namespace
 
 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
   patterns.insert<CancelDelinearizeOfLinearizeDisjointExactTail,
-                  DropUnitExtentBasis,
-                  ConstantAttributeBasisDelinearizeIndexOpPattern>(context);
+                  DropUnitExtentBasis>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -4832,6 +4818,8 @@ LogicalResult AffineLinearizeIndexOp::verify() {
 }
 
 OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
+  setStaticBasis(foldCstValueToCstAttrBasis(
+      getMixedBasis(), getDynamicBasisMutable(), adaptor.getDynamicBasis()));
   // No indices linearizes to zero.
   if (getMultiIndex().empty())
     return IntegerAttr::get(getResult().getType(), 0);
@@ -5002,31 +4990,12 @@ struct DropLinearizeLeadingZero final
     return success();
   }
 };
-
-/// Folds away constant SSA Value with constant Attribute in basis.
-struct ConstantAttributeBasisLinearizeIndexOpPattern
-    : public OpRewritePattern<affine::AffineLinearizeIndexOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
-                                PatternRewriter &rewriter) const override {
-    // Replace all constant SSA values with the constant attribute.
-    SmallVector<OpFoldResult> newBasis;
-    if (failed(fetchNewConstantBasis(rewriter, op.getMixedBasis(), newBasis)))
-      return rewriter.notifyMatchFailure(op, "no constant SSA value in basis");
-
-    rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
-        op, op.getMultiIndex(), newBasis, op.getDisjoint());
-    return success();
-  }
-};
 } // namespace
 
 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
   patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeLeadingZero,
-               DropLinearizeUnitComponentsIfDisjointOrZero,
-               ConstantAttributeBasisLinearizeIndexOpPattern>(context);
+               DropLinearizeUnitComponentsIfDisjointOrZero>(context);
 }
 
 //===----------------------------------------------------------------------===//

>From fb10fe82eb2f71181e3a660a61919a6388d91abd Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 27 Nov 2024 06:09:01 +0000
Subject: [PATCH 3/3] Review comment: return std::nullopt when no basis operand folds

---
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp   | 50 +++++++++++++++-------
 mlir/test/Dialect/Affine/canonicalize.mlir | 17 ++++----
 2 files changed, 42 insertions(+), 25 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index e48db6d5f521eb..7341057f4e7226 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4569,13 +4569,27 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
   return success();
 }
 
-/// Give mixed basis of affine.delinearize_index/linearize_index replace
-/// constant SSA values with the constant integer value and returns the
-/// new static basis.
-static SmallVector<int64_t>
+/// Given the mixed basis of affine.delinearize_index/linearize_index, erase
+/// every dynamic basis operand that folds to an integer constant and return
+/// the new static basis, which the caller must install via setStaticBasis().
+/// Returns std::nullopt (leaving the operands untouched) if none fold.
+static std::optional<SmallVector<int64_t>>
 foldCstValueToCstAttrBasis(ArrayRef<OpFoldResult> mixedBasis,
                            MutableOperandRange mutableDynamicBasis,
                            ArrayRef<Attribute> dynamicBasis) {
+  unsigned dynamicBasisIndex = 0;
+  for (Attribute basis : dynamicBasis) {
+    if (isa_and_nonnull<IntegerAttr>(basis)) {
+      mutableDynamicBasis.erase(dynamicBasisIndex);
+    } else {
+      ++dynamicBasisIndex;
+    }
+  }
+
+  // No dynamic basis operand folded to an integer constant.
+  if (dynamicBasisIndex == dynamicBasis.size())
+    return std::nullopt;
+
   SmallVector<int64_t> staticBasis;
   for (OpFoldResult basis : mixedBasis) {
     std::optional<int64_t> basisVal = getConstantIntValue(basis);
@@ -4585,22 +4599,19 @@ foldCstValueToCstAttrBasis(ArrayRef<OpFoldResult> mixedBasis,
       staticBasis.push_back(*basisVal);
   }
 
-  int64_t dynamicBasisIndex = 0;
-  for (OpFoldResult basis : dynamicBasis) {
-    if (basis) {
-      mutableDynamicBasis.erase(dynamicBasisIndex);
-    } else {
-      ++dynamicBasisIndex;
-    }
-  }
   return staticBasis;
 }
 
 LogicalResult
 AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &result) {
-  setStaticBasis(foldCstValueToCstAttrBasis(
-      getMixedBasis(), getDynamicBasisMutable(), adaptor.getDynamicBasis()));
+  std::optional<SmallVector<int64_t>> maybeStaticBasis =
+      foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
+                                 adaptor.getDynamicBasis());
+  if (maybeStaticBasis) {
+    setStaticBasis(*maybeStaticBasis);
+    return success();
+  }
   // If we won't be doing any division or modulo (no basis or the one basis
   // element is purely advisory), simply return the input value.
   if (getNumResults() == 1) {
@@ -4818,8 +4829,15 @@ LogicalResult AffineLinearizeIndexOp::verify() {
 }
 
 OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
-  setStaticBasis(foldCstValueToCstAttrBasis(
-      getMixedBasis(), getDynamicBasisMutable(), adaptor.getDynamicBasis()));
+  std::optional<SmallVector<int64_t>> maybeStaticBasis =
+      foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
+                                 adaptor.getDynamicBasis());
+  if (maybeStaticBasis) {
+    setStaticBasis(*maybeStaticBasis);
+    return getResult();
+  }
+  // Returning getResult() above signals an in-place fold; otherwise the basis
+  // is already as static as possible, so try the remaining folds below.
   // No indices linearizes to zero.
   if (getMultiIndex().empty())
     return IntegerAttr::get(getResult().getType(), 0);
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index 16cbce35aeec7e..b747178c5b1a94 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1950,15 +1950,14 @@ func.func @affine_leading_zero_no_outer_bound(%arg0: index, %arg1: index) -> ind
 
 // CHECK-LABEL: @cst_value_to_cst_attr_basis_delinearize_index
 // CHECK-SAME:    (%[[ARG0:.*]]: index)
-// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
-// CHECK:         %[[RET:.*]]:2 = affine.delinearize_index %[[ARG0]] into (3, 4) : index, index
-// CHECK:         return %[[RET]]#0, %[[RET]]#1, %[[C0]] : index, index, index
+// CHECK:         %[[RET:.*]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 2) : index, index, index
+// CHECK:         return %[[RET]]#0, %[[RET]]#1, %[[RET]]#2 : index, index, index
 func.func @cst_value_to_cst_attr_basis_delinearize_index(%arg0 : index) ->
     (index, index, index) {
   %c4 = arith.constant 4 : index
   %c3 = arith.constant 3 : index
-  %c1 = arith.constant 1 : index
-  %0:3 = affine.delinearize_index %arg0 into (%c3, %c4, %c1)
+  %c2 = arith.constant 2 : index
+  %0:3 = affine.delinearize_index %arg0 into (%c3, %c4, %c2)
       : index, index, index
   return %0#0, %0#1, %0#2 : index, index, index
 }
@@ -1967,12 +1966,12 @@ func.func @cst_value_to_cst_attr_basis_delinearize_index(%arg0 : index) ->
 
 // CHECK-LABEL: @cst_value_to_cst_attr_basis_linearize_index
 // CHECK-SAME:    (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
-// CHECK:         %[[RET:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]]] by (3, 4) : index
+// CHECK:         %[[RET:.*]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (2, 3, 4) : index
 // CHECK:         return %[[RET]] : index
 func.func @cst_value_to_cst_attr_basis_linearize_index(%arg0 : index, %arg1 : index, %arg2 : index) ->
     (index) {
   %c4 = arith.constant 4 : index
-  %c1 = arith.constant 1 : index
-  %1 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by  (%c1, 3, %c4) : index
-  return %1 : index
+  %c2 = arith.constant 2 : index
+  %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%c2, 3, %c4) : index
+  return %0 : index
 }



More information about the Mlir-commits mailing list