[Mlir-commits] [mlir] 49df068 - [mlir][arith][NFC] Simplify narrowing patterns with a wrapper type
Jakub Kuderski
llvmlistbot at llvm.org
Mon May 1 10:35:42 PDT 2023
Author: Jakub Kuderski
Date: 2023-05-01T13:31:31-04:00
New Revision: 49df068836bfbb538771395d8bb293548afd414e
URL: https://github.com/llvm/llvm-project/commit/49df068836bfbb538771395d8bb293548afd414e
DIFF: https://github.com/llvm/llvm-project/commit/49df068836bfbb538771395d8bb293548afd414e.diff
LOG: [mlir][arith][NFC] Simplify narrowing patterns with a wrapper type
Add a new wrapper type that represents either of `ExtSIOp` or `ExtUIOp`.
This is to simplify the code by using a single type, so that we do not
have to use templates or branching to handle both extension kinds.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D149485
Added:
Modified:
mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 639b19b0a5d8a..c515824a8e04d 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -15,13 +15,13 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/TypeSwitch.h"
#include <cassert>
#include <cstdint>
@@ -100,11 +100,63 @@ FailureOr<unsigned> calculateBitsRequired(Type type) {
enum class ExtensionKind { Sign, Zero };
-ExtensionKind getExtensionKind(Operation *op) {
- assert(op);
- assert((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) && "Not an extension op");
- return isa<arith::ExtSIOp>(op) ? ExtensionKind::Sign : ExtensionKind::Zero;
-}
+/// Wrapper around `arith::ExtSIOp` and `arith::ExtUIOp` ops that abstracts away
+/// the exact op type. Exposes helper functions to query the types, operands,
+/// and the result. This is so that we can handle both extension kinds without
+/// needing to use templates or branching.
+class ExtensionOp {
+public:
+ /// Attempts to create a new extension op from `op`. Returns an extension op
+ /// wrapper when `op` is either `arith.extsi` or `arith.extui`, and failure
+ /// otherwise.
+ static FailureOr<ExtensionOp> from(Operation *op) {
+ if (auto sext = dyn_cast_or_null<arith::ExtSIOp>(op))
+ return ExtensionOp{op, ExtensionKind::Sign};
+ if (auto zext = dyn_cast_or_null<arith::ExtUIOp>(op))
+ return ExtensionOp{op, ExtensionKind::Zero};
+
+ return failure();
+ }
+
+ ExtensionOp(const ExtensionOp &) = default;
+ ExtensionOp &operator=(const ExtensionOp &) = default;
+
+ /// Creates a new extension op of the same kind.
+ Operation *recreate(PatternRewriter &rewriter, Location loc, Type newType,
+ Value in) {
+ if (kind == ExtensionKind::Sign)
+ return rewriter.create<arith::ExtSIOp>(loc, newType, in);
+
+ return rewriter.create<arith::ExtUIOp>(loc, newType, in);
+ }
+
+ /// Replaces `toReplace` with a new extension op of the same kind.
+ void recreateAndReplace(PatternRewriter &rewriter, Operation *toReplace,
+ Value in) {
+ assert(toReplace->getNumResults() == 1);
+ Type newType = toReplace->getResult(0).getType();
+ Operation *newOp = recreate(rewriter, toReplace->getLoc(), newType, in);
+ rewriter.replaceOp(toReplace, newOp->getResult(0));
+ }
+
+ ExtensionKind getKind() { return kind; }
+
+ Value getResult() { return op->getResult(0); }
+ Value getIn() { return op->getOperand(0); }
+
+ Type getType() { return getResult().getType(); }
+ Type getElementType() { return getElementTypeOrSelf(getType()); }
+ Type getInType() { return getIn().getType(); }
+ Type getInElementType() { return getElementTypeOrSelf(getInType()); }
+
+private:
+ ExtensionOp(Operation *op, ExtensionKind kind) : op(op), kind(kind) {
+ assert(op);
+ assert((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) && "Not an extension op");
+ }
+ Operation *op = nullptr;
+ ExtensionKind kind = {};
+};
/// Returns the integer bitwidth required to represent `value`.
unsigned calculateBitsRequired(const APInt &value,
@@ -202,19 +254,15 @@ struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
LogicalResult matchAndRewrite(vector::ExtractOp op,
PatternRewriter &rewriter) const override {
- Operation *def = op.getVector().getDefiningOp();
- if (!def)
+ FailureOr<ExtensionOp> ext =
+ ExtensionOp::from(op.getVector().getDefiningOp());
+ if (failed(ext))
return failure();
- return TypeSwitch<Operation *, LogicalResult>(def)
- .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
- Value newExtract = rewriter.create<vector::ExtractOp>(
- op.getLoc(), extOp.getIn(), op.getPosition());
- rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
- newExtract);
- return success();
- })
- .Default(failure());
+ Value newExtract = rewriter.create<vector::ExtractOp>(
+ op.getLoc(), ext->getIn(), op.getPosition());
+ ext->recreateAndReplace(rewriter, op, newExtract);
+ return success();
}
};
@@ -224,19 +272,15 @@ struct ExtensionOverExtractElement final
LogicalResult matchAndRewrite(vector::ExtractElementOp op,
PatternRewriter &rewriter) const override {
- Operation *def = op.getVector().getDefiningOp();
- if (!def)
+ FailureOr<ExtensionOp> ext =
+ ExtensionOp::from(op.getVector().getDefiningOp());
+ if (failed(ext))
return failure();
- return TypeSwitch<Operation *, LogicalResult>(def)
- .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
- Value newExtract = rewriter.create<vector::ExtractElementOp>(
- op.getLoc(), extOp.getIn(), op.getPosition());
- rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
- newExtract);
- return success();
- })
- .Default(failure());
+ Value newExtract = rewriter.create<vector::ExtractElementOp>(
+ op.getLoc(), ext->getIn(), op.getPosition());
+ ext->recreateAndReplace(rewriter, op, newExtract);
+ return success();
}
};
@@ -246,24 +290,19 @@ struct ExtensionOverExtractStridedSlice final
LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
- Operation *def = op.getVector().getDefiningOp();
- if (!def)
+ FailureOr<ExtensionOp> ext =
+ ExtensionOp::from(op.getVector().getDefiningOp());
+ if (failed(ext))
return failure();
- return TypeSwitch<Operation *, LogicalResult>(def)
- .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
- VectorType origTy = op.getType();
- Type inElemTy =
- cast<VectorType>(extOp.getIn().getType()).getElementType();
- VectorType extractTy = origTy.cloneWith(origTy.getShape(), inElemTy);
- Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
- op.getLoc(), extractTy, extOp.getIn(), op.getOffsets(),
- op.getSizes(), op.getStrides());
- rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
- newExtract);
- return success();
- })
- .Default(failure());
+ VectorType origTy = op.getType();
+ VectorType extractTy =
+ origTy.cloneWith(origTy.getShape(), ext->getInElementType());
+ Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
+ op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
+ op.getStrides());
+ ext->recreateAndReplace(rewriter, op, newExtract);
+ return success();
}
};
@@ -272,30 +311,22 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
LogicalResult matchAndRewrite(vector::InsertOp op,
PatternRewriter &rewriter) const override {
- Operation *def = op.getSource().getDefiningOp();
- if (!def)
+ FailureOr<ExtensionOp> ext =
+ ExtensionOp::from(op.getSource().getDefiningOp());
+ if (failed(ext))
return failure();
- return TypeSwitch<Operation *, LogicalResult>(def)
- .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
- // Rewrite the insertion in terms of narrower operands
- // and later extend the result to the original bitwidth.
- FailureOr<vector::InsertOp> newInsert =
- createNarrowInsert(op, rewriter, extOp);
- if (failed(newInsert))
- return failure();
- rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
- *newInsert);
- return success();
- })
- .Default(failure());
+ FailureOr<vector::InsertOp> newInsert =
+ createNarrowInsert(op, rewriter, *ext);
+ if (failed(newInsert))
+ return failure();
+ ext->recreateAndReplace(rewriter, op, *newInsert);
+ return success();
}
FailureOr<vector::InsertOp> createNarrowInsert(vector::InsertOp op,
PatternRewriter &rewriter,
- Operation *insValue) const {
- assert((isa<arith::ExtSIOp, arith::ExtUIOp>(insValue)));
-
+ ExtensionOp insValue) const {
// Calculate the operand and result bitwidths. We can only apply narrowing
// when the inserted source value and destination vector require fewer bits
// than the result. Because the source and destination may have different
@@ -306,14 +337,13 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
if (failed(origBitsRequired))
return failure();
- ExtensionKind kind = getExtensionKind(insValue);
FailureOr<unsigned> destBitsRequired =
- calculateBitsRequired(op.getDest(), kind);
+ calculateBitsRequired(op.getDest(), insValue.getKind());
if (failed(destBitsRequired) || *destBitsRequired >= *origBitsRequired)
return failure();
FailureOr<unsigned> insertedBitsRequired =
- calculateBitsRequired(insValue->getOperands().front(), kind);
+ calculateBitsRequired(insValue.getIn(), insValue.getKind());
if (failed(insertedBitsRequired) ||
*insertedBitsRequired >= *origBitsRequired)
return failure();
@@ -327,13 +357,13 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
return failure();
FailureOr<Type> newInsertedValueTy =
- getNarrowType(newInsertionBits, insValue->getResultTypes().front());
+ getNarrowType(newInsertionBits, insValue.getType());
if (failed(newInsertedValueTy))
return failure();
Location loc = op.getLoc();
Value narrowValue = rewriter.createOrFold<arith::TruncIOp>(
- loc, *newInsertedValueTy, insValue->getResult(0));
+ loc, *newInsertedValueTy, insValue.getResult());
Value narrowDest =
rewriter.createOrFold<arith::TruncIOp>(loc, *newVecTy, op.getDest());
return rewriter.create<vector::InsertOp>(loc, narrowValue, narrowDest,
More information about the Mlir-commits
mailing list