[Mlir-commits] [mlir] [mlir][vector] Constant fold scalable `affine.min` (PR #106752)
Benjamin Maxwell
llvmlistbot at llvm.org
Fri Aug 30 08:43:20 PDT 2024
https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/106752
None
>From 60e7583f239bf5609f577ea8053a0b23a0f0bd05 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 30 Aug 2024 15:31:57 +0000
Subject: [PATCH] [mlir][vector] Constant fold scalable `affine.min`
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 2 +
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 116 ++++++++++++++++++
mlir/test/Dialect/Vector/canonicalize.mlir | 20 +++
3 files changed, 138 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index b96f5c2651bce5..5ebb7e3307934f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2857,6 +2857,8 @@ def VectorScaleOp : Vector_Op<"vscale",
setNameFn(getResult(), "vscale");
}
}];
+
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 44bd4aa76ffbd6..27a5a35280d0d7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -706,6 +707,121 @@ void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ElideSingleElementReduction>(context);
}
+//===----------------------------------------------------------------------===//
+// VectorScaleOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Folds `affine.min affine_map<(d0)[s0] -> (-d0 + s0, cst)>(%iv)[%ub]` to the
+/// constant `cst` when `%iv` is the single induction variable of the nearest
+/// enclosing loop, that loop runs from 0 to `%ub` with step `cst`, and both
+/// `%ub` and the symbol operand are `cst * vector.vscale`. In that case the
+/// remaining distance `s0 - d0` is always at least `cst`.
+class FoldScalableAffineMin : public OpRewritePattern<affine::AffineMinOp> {
+  using OpRewritePattern<affine::AffineMinOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineMinOp affineMin,
+                                PatternRewriter &rewriter) const override {
+    // Only the `(d0)[s0]` form is handled: one dim (the IV), one symbol.
+    if (affineMin.getDimOperands().size() != 1)
+      return failure();
+
+    if (affineMin.getSymbolOperands().size() != 1)
+      return failure();
+
+    Value symbolOperand = affineMin.getSymbolOperands()[0];
+    Value dimOperand = affineMin.getDimOperands()[0];
+
+    // The symbol must be `cst * vscale`. The proof below requires `cst > 0`:
+    // for a negative multiplier `cst * (vscale - x)` is <= `cst`, not >=.
+    auto symbolVscaleMultiple =
+        vector::getConstantVscaleMultiplier(symbolOperand);
+    if (!symbolVscaleMultiple || *symbolVscaleMultiple <= 0)
+      return failure();
+
+    // NOTE(review): the reasoning below assumes the loop iterates
+    // `lb <= iv < ub` with a positive step (as scf.for does); confirm this
+    // holds for every LoopLikeOpInterface implementation that can match.
+    auto loop = affineMin->getParentOfType<LoopLikeOpInterface>();
+    if (!loop)
+      return failure();
+    auto inductionVar = loop.getSingleInductionVar();
+    auto lowerBound = loop.getSingleLowerBound();
+    auto upperBound = loop.getSingleUpperBound();
+    auto step = loop.getSingleStep();
+
+    if (!inductionVar || !lowerBound || !upperBound || !step)
+      return failure();
+
+    // `d0` must be this loop's induction variable.
+    if (*inductionVar != dimOperand)
+      return failure();
+
+    // The step must equal the vscale multiplier of the symbol.
+    if (getConstantIntValue(*step) != symbolVscaleMultiple)
+      return failure();
+
+    if (!isZeroIndex(*lowerBound))
+      return failure();
+
+    auto upperBoundValue = llvm::dyn_cast_if_present<Value>(*upperBound);
+    if (!upperBoundValue)
+      return failure();
+
+    // The upper bound must be the same multiple of vscale as the symbol.
+    auto upperBoundVscaleMultiple =
+        vector::getConstantVscaleMultiplier(upperBoundValue);
+
+    if (upperBoundVscaleMultiple != symbolVscaleMultiple)
+      return failure();
+
+    auto map = affineMin.getAffineMap();
+
+    // The map must have exactly two results, `-d0 + s0` and `cst`. Any extra
+    // result could be smaller than `cst` (making the fold wrong), and with a
+    // single result `getResult(1)` would be out of bounds.
+    if (map.getNumResults() != 2)
+      return failure();
+
+    // Matches `-d0 + s0`, which is represented as `(d0 * -1) + s0`.
+    auto isSymbolMinusDim = [](AffineExpr expr) {
+      auto binop = dyn_cast<AffineBinaryOpExpr>(expr);
+      if (!binop || binop.getKind() != AffineExprKind::Add)
+        return false;
+      if (!isa<AffineSymbolExpr>(binop.getRHS()))
+        return false;
+      auto neg = dyn_cast<AffineBinaryOpExpr>(binop.getLHS());
+      if (!neg || neg.getKind() != AffineExprKind::Mul)
+        return false;
+      if (!isa<AffineDimExpr>(neg.getLHS()))
+        return false;
+      auto cst = dyn_cast<AffineConstantExpr>(neg.getRHS());
+      if (!cst || cst.getValue() != -1)
+        return false;
+      return true;
+    };
+
+    if (!isSymbolMinusDim(map.getResult(0)))
+      return failure();
+    auto cst = dyn_cast<AffineConstantExpr>(map.getResult(1));
+    if (!cst || cst.getValue() != upperBoundVscaleMultiple)
+      return failure();
+
+    // Otherwise, with c = cst (c > 0), we know:
+    //   inductionVar >= 0
+    //   inductionVar < c * vscale
+    //   step == c
+    //
+    // So:
+    //   inductionVar == c * x, for some integer x with 0 <= x < vscale
+    //   symbolOperand == c * vscale
+    //   dimOperand == inductionVar == c * x
+    //
+    //   min(-d0 + s0, c)
+    //     = min(-(c * x) + (c * vscale), c)
+    //     = min(c * (vscale - x), c)
+    //   vscale - x >= 1, so c * (vscale - x) >= c,
+    //   hence min(c * (vscale - x), c) == c.
+
+    rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(affineMin,
+                                                        *symbolVscaleMultiple);
+    return success();
+  }
+};
+
+} // namespace
+
+void VectorScaleOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                MLIRContext *ctx) {
+  // FIXME: FoldScalableAffineMin rewrites `affine.min`, not `vector.vscale`
+  // itself. It is registered on vector.vscale only because any IR it can fold
+  // necessarily contains a vector.vscale op, and no better home exists yet.
+  patterns.add<FoldScalableAffineMin>(ctx);
+}
+
//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e71a6eb02ea46c..5cd7b14f8a9c6c 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2740,3 +2740,23 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
%1 = vector.insert %arg0, %0 [0] : i8 into vector<4xi8>
return %1 : vector<4xi8>
}
+
+// -----
+
+// Within `for %iv = 0 to 8 * vscale step 8`, the distance `8 * vscale - %iv`
+// never drops below 8, so `min(8 * vscale - %iv, 8)` folds to the constant 8.
+// CHECK-LABEL: @redundant_scalable_affine_min
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: scf.for
+// CHECK-NOT: affine.min
+// CHECK: "test.some_use"(%[[C8]]) : (index) -> ()
+// CHECK-NOT: affine.min
+func.func @redundant_scalable_affine_min() {
+  %zero = arith.constant 0 : index
+  %step = arith.constant 8 : index
+  %vs = vector.vscale
+  %upper = arith.muli %step, %vs : index
+  scf.for %iv = %zero to %upper step %step {
+    %clamped = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 8)>(%iv)[%upper]
+    "test.some_use"(%clamped) : (index) -> ()
+  }
+  return
+}
More information about the Mlir-commits
mailing list