[Mlir-commits] [mlir] [mlir][vector] Constant fold scalable `affine.min` (PR #106752)

Benjamin Maxwell llvmlistbot at llvm.org
Fri Aug 30 08:43:20 PDT 2024


https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/106752

None

>From 60e7583f239bf5609f577ea8053a0b23a0f0bd05 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 30 Aug 2024 15:31:57 +0000
Subject: [PATCH] [mlir][vector] Constant fold scalable `affine.min`

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |   2 +
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 116 ++++++++++++++++++
 mlir/test/Dialect/Vector/canonicalize.mlir    |  20 +++
 3 files changed, 138 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index b96f5c2651bce5..5ebb7e3307934f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2857,6 +2857,8 @@ def VectorScaleOp : Vector_Op<"vscale",
       setNameFn(getResult(), "vscale");
     }
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 44bd4aa76ffbd6..27a5a35280d0d7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -706,6 +707,121 @@ void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ElideSingleElementReduction>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// VectorScaleOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Folds min(c * vscale - iv, c) to c for iv in [0, c * vscale) with step c.
+class FoldScalableAffineMin : public OpRewritePattern<affine::AffineMinOp> {
+  using OpRewritePattern<affine::AffineMinOp>::OpRewritePattern;
+  // Matches only the map (d0)[s0] -> (-d0 + s0, c); fails on anything else.
+  LogicalResult matchAndRewrite(affine::AffineMinOp affineMin,
+                                PatternRewriter &rewriter) const override {
+    if (affineMin.getDimOperands().size() != 1)
+      return failure();
+
+    if (affineMin.getSymbolOperands().size() != 1)
+      return failure();
+
+    Value symbolOperand = affineMin.getSymbolOperands()[0];
+    Value dimOperand = affineMin.getDimOperands()[0];
+    // The symbol operand must be a constant multiple of vscale (c * vscale).
+    auto symbolVscaleMultiple =
+        vector::getConstantVscaleMultiplier(symbolOperand);
+    if (!symbolVscaleMultiple)
+      return failure();
+    // The closest enclosing loop must expose a single IV, bounds and step.
+    auto loop = affineMin->getParentOfType<LoopLikeOpInterface>();
+    if (!loop)
+      return failure();
+    auto inductionVar = loop.getSingleInductionVar();
+    auto lowerBound = loop.getSingleLowerBound();
+    auto upperBound = loop.getSingleUpperBound();
+    auto step = loop.getSingleStep();
+
+    if (!inductionVar || !lowerBound || !upperBound || !step)
+      return failure();
+    // The dim operand must be that loop's induction variable.
+    if (*inductionVar != dimOperand)
+      return failure();
+    // The loop step must be the constant c that multiplies vscale.
+    if (getConstantIntValue(*step) != symbolVscaleMultiple)
+      return failure();
+    // The loop must start at zero.
+    if (!isZeroIndex(*lowerBound))
+      return failure();
+    // The upper bound must be an SSA value that is also c * vscale.
+    auto upperBoundValue = llvm::dyn_cast_if_present<Value>(*upperBound);
+    if (!upperBoundValue)
+      return failure();
+
+    auto upperBoundVscaleMultiple =
+        vector::getConstantVscaleMultiplier(upperBoundValue);
+
+    if (upperBoundVscaleMultiple != symbolVscaleMultiple)
+      return failure();
+
+    auto map = affineMin.getAffineMap();
+    // Matches the expression tree (d0 * -1) + s0, i.e. -d0 + s0.
+    auto isSymbolMinusDim = [](AffineExpr expr) {
+      auto binop = dyn_cast<AffineBinaryOpExpr>(expr);
+      if (!binop || binop.getKind() != AffineExprKind::Add)
+        return false;
+      if (!isa<AffineSymbolExpr>(binop.getRHS()))
+        return false;
+      auto neg = dyn_cast<AffineBinaryOpExpr>(binop.getLHS());
+      if (!neg || neg.getKind() != AffineExprKind::Mul)
+        return false;
+      if (!isa<AffineDimExpr>(neg.getLHS()))
+        return false;
+      auto cst = dyn_cast<AffineConstantExpr>(neg.getRHS());
+      if (!cst || cst.getValue() != -1)
+        return false;
+      return true;
+    };
+    // Need exactly two results: getResult(1) must exist, none may be ignored.
+    if (map.getNumResults() != 2 || !isSymbolMinusDim(map.getResult(0)))
+      return failure();
+    auto cst = dyn_cast<AffineConstantExpr>(map.getResult(1));
+    if (!cst || cst.getValue() != upperBoundVscaleMultiple)
+      return failure();
+
+    // Otherwise, we know:
+    // inductionVar >= 0
+    // inductionVar < cst * vscale
+    // step == cst
+    //
+    // So, for some integer x:
+    // inductionVar == x * cst
+    // x >= 0
+    // x < vscale
+    //
+    // symbolOperand == cst * vscale
+    // dimOperand == inductionVar == x * cst
+    //
+    // min(-d0 + s0, cst)
+    // = min(-(cst * x) + (cst * vscale), cst)
+    // = min(cst * (vscale - x), cst)
+    // vscale - x >= 1, so cst * (vscale - x) >= cst (assuming cst > 0)
+    // so min(cst * (vscale - x), cst) == cst
+
+    rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(affineMin,
+                                                        *symbolVscaleMultiple);
+    return success();
+  }
+};
+
+} // namespace
+
+void VectorScaleOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  // Registers FoldScalableAffineMin. FIXME: It is not _really_ a vector.vscale
+  // pattern (though a vector.vscale op will always be present when the fold
+  // applies); it lives here for lack of a better place.
+  results.add<FoldScalableAffineMin>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ContractionOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e71a6eb02ea46c..5cd7b14f8a9c6c 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2740,3 +2740,23 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
   %1 = vector.insert %arg0, %0 [0] : i8 into vector<4xi8>
   return %1 : vector<4xi8>
 }
+
+// -----
+// %i steps by 8 over [0, 8 * vscale), so min(8 * vscale - %i, 8) is always 8.
+// CHECK-LABEL: @redundant_scalable_affine_min
+//       CHECK:   %[[C8:.*]] = arith.constant 8 : index
+//       CHECK:   scf.for
+//   CHECK-NOT:     affine.min
+//       CHECK:     "test.some_use"(%[[C8]]) : (index) -> ()
+//   CHECK-NOT:     affine.min
+func.func @redundant_scalable_affine_min() {
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %vscale = vector.vscale
+  %c8_vscale = arith.muli %c8, %vscale : index
+  scf.for %i = %c0 to %c8_vscale step %c8 {
+    %min = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 8)>(%i)[%c8_vscale]
+    "test.some_use"(%min) : (index) -> ()
+  }
+  return
+}



More information about the Mlir-commits mailing list