[Mlir-commits] [mlir] d83148f - [MLIR][Affine] Update ::fold() to have constant basis attr for affine.delinearize_index/linearize_index (#117572)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 28 20:52:26 PST 2024


Author: Abhishek Varma
Date: 2024-11-29T10:22:23+05:30
New Revision: d83148f9b9debde1358f0686594da208ce33182e

URL: https://github.com/llvm/llvm-project/commit/d83148f9b9debde1358f0686594da208ce33182e
DIFF: https://github.com/llvm/llvm-project/commit/d83148f9b9debde1358f0686594da208ce33182e.diff

LOG: [MLIR][Affine] Update ::fold() to have constant basis attr for affine.delinearize_index/linearize_index (#117572)

-- This commit updates `::fold()` of affine.delinearize_index and
   affine.linearize_index to move constant (CST) basis values into the
   op's static basis attribute wherever applicable.
-- Concretely, the code checks whether any dynamic basis operand is a
   constant SSA value and, if so, erases that operand and records the
   corresponding constant integer in the static basis instead.

Signed-off-by: Abhishek Varma <abhvarma at amd.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/test/Dialect/Affine/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 1c5466730a5589..03549fb2e0fa91 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4569,9 +4569,49 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
   return success();
 }
 
+/// Given the mixed basis of an affine.delinearize_index/linearize_index op,
+/// erase every dynamic basis operand whose entry in `dynamicBasis` is non-null
+/// (i.e. constant) and return the new static basis with those values inlined.
+/// Returns std::nullopt, leaving the operands untouched, if none is constant.
+static std::optional<SmallVector<int64_t>>
+foldCstValueToCstAttrBasis(ArrayRef<OpFoldResult> mixedBasis,
+                           MutableOperandRange mutableDynamicBasis,
+                           ArrayRef<Attribute> dynamicBasis) {
+  // Unsigned so the comparison with dynamicBasis.size() is not mixed-sign.
+  unsigned dynamicBasisIndex = 0;
+  for (Attribute cstAttr : dynamicBasis) {
+    // Erasing shifts later operands down, so only advance on a kept operand.
+    if (cstAttr)
+      mutableDynamicBasis.erase(dynamicBasisIndex);
+    else
+      ++dynamicBasisIndex;
+  }
+
+  // No constant SSA value exists.
+  if (dynamicBasisIndex == dynamicBasis.size())
+    return std::nullopt;
+
+  // `mixedBasis` is computed before the erasure (as the callers do), so it
+  // still lists every element; constants go in as-is, the rest as dynamic.
+  SmallVector<int64_t> staticBasis;
+  staticBasis.reserve(mixedBasis.size());
+  for (OpFoldResult basis : mixedBasis) {
+    std::optional<int64_t> basisVal = getConstantIntValue(basis);
+    staticBasis.push_back(basisVal ? *basisVal : ShapedType::kDynamic);
+  }
+  return staticBasis;
+}
+
 LogicalResult
 AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &result) {
+  std::optional<SmallVector<int64_t>> maybeStaticBasis =
+      foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
+                                 adaptor.getDynamicBasis());
+  if (maybeStaticBasis) {
+    setStaticBasis(*maybeStaticBasis);
+    return success();
+  }
   // If we won't be doing any division or modulo (no basis or the one basis
   // element is purely advisory), simply return the input value.
   if (getNumResults() == 1) {
@@ -4875,6 +4915,13 @@ LogicalResult AffineLinearizeIndexOp::verify() {
 }
 
 OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
+  std::optional<SmallVector<int64_t>> maybeStaticBasis =
+      foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
+                                 adaptor.getDynamicBasis());
+  if (maybeStaticBasis) {
+    setStaticBasis(*maybeStaticBasis);
+    return getResult();
+  }
   // No indices linearizes to zero.
   if (getMultiIndex().empty())
     return IntegerAttr::get(getResult().getType(), 0);

diff  --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index d3f61f7e503f9b..717004eb50c0fc 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -2012,3 +2012,32 @@ func.func @affine_leading_zero_no_outer_bound(%arg0: index, %arg1: index) -> ind
   return %ret : index
 }
 
+// -----
+
+// CHECK-LABEL: @cst_value_to_cst_attr_basis_delinearize_index
+// CHECK-SAME:    (%[[ARG0:.*]]: index)
+// CHECK:         %[[RET:.*]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 2) : index, index, index
+// CHECK:         return %[[RET]]#0, %[[RET]]#1, %[[RET]]#2 : index, index, index
+func.func @cst_value_to_cst_attr_basis_delinearize_index(%arg0 : index) ->
+    (index, index, index) {
+  %c4 = arith.constant 4 : index
+  %c3 = arith.constant 3 : index
+  %c2 = arith.constant 2 : index
+  %0:3 = affine.delinearize_index %arg0 into (%c3, %c4, %c2)
+      : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @cst_value_to_cst_attr_basis_linearize_index
+// CHECK-SAME:    (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+// CHECK:         %[[RET:.*]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (2, 3, 4) : index
+// CHECK:         return %[[RET]] : index
+func.func @cst_value_to_cst_attr_basis_linearize_index(
+    %arg0 : index, %arg1 : index, %arg2 : index) -> index {
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+  %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%c2, 3, %c4) : index
+  return %0 : index
+}


        


More information about the Mlir-commits mailing list