[Mlir-commits] [mlir] 09043a2 - [mlir][arith] Add patterns to commute extension over vector extraction
Jakub Kuderski
llvmlistbot at llvm.org
Fri Apr 28 10:50:48 PDT 2023
Author: Jakub Kuderski
Date: 2023-04-28T13:48:50-04:00
New Revision: 09043a26c85dad1ae33a32c0927d467f622de157
URL: https://github.com/llvm/llvm-project/commit/09043a26c85dad1ae33a32c0927d467f622de157
DIFF: https://github.com/llvm/llvm-project/commit/09043a26c85dad1ae33a32c0927d467f622de157.diff
LOG: [mlir][arith] Add patterns to commute extension over vector extraction
This moves zero/sign-extension ops closer to their use and exposes more
narrowing optimization opportunities.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D149233
Added:
Modified:
mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
mlir/test/Dialect/Arith/int-narrowing.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index e6fe4680d77cb..50b748435afa9 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -90,6 +90,7 @@ def ArithIntNarrowing : Pass<"arith-int-narrowing"> {
prefers the narrowest available integer bitwidths that are guaranteed to
produce the same results.
}];
+ let dependentDialects = ["vector::VectorDialect"];
let options = [
ListOption<"bitwidthsSupported", "int-bitwidths-supported", "unsigned",
"Integer bitwidths supported">,
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index e884a19bae1cb..3401a9c05b632 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -9,16 +9,19 @@
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
#include <cassert>
#include <cstdint>
@@ -143,6 +146,80 @@ struct IToFPPattern final : NarrowingPattern<IToFPOp> {
using SIToFPPattern = IToFPPattern<arith::SIToFPOp, ExtensionKind::Sign>;
using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;
+//===----------------------------------------------------------------------===//
+// Patterns to Commute Extension Ops
+//===----------------------------------------------------------------------===//
+/// Rewrites `vector.extract` of an extsi/extui result to extract first.
+struct ExtensionOverExtract final : OpRewritePattern<vector::ExtractOp> {
+ using OpRewritePattern::OpRewritePattern;
+ /// Fails unless the source vector is defined by arith.extsi / arith.extui.
+ LogicalResult matchAndRewrite(vector::ExtractOp op,
+ PatternRewriter &rewriter) const override {
+ Operation *def = op.getVector().getDefiningOp();
+ if (!def)
+ return failure();
+ // Extract from the un-extended input, then re-apply the same extension op.
+ return TypeSwitch<Operation *, LogicalResult>(def)
+ .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
+ Value newExtract = rewriter.create<vector::ExtractOp>(
+ op.getLoc(), extOp.getIn(), op.getPosition());
+ rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
+ newExtract);
+ return success();
+ })
+ .Default(failure());
+ }
+};
+/// Rewrites `vector.extractelement` of an extsi/extui result to extract first.
+struct ExtensionOverExtractElement final
+ : OpRewritePattern<vector::ExtractElementOp> {
+ using OpRewritePattern::OpRewritePattern;
+ /// Fails unless the source vector is defined by arith.extsi / arith.extui.
+ LogicalResult matchAndRewrite(vector::ExtractElementOp op,
+ PatternRewriter &rewriter) const override {
+ Operation *def = op.getVector().getDefiningOp();
+ if (!def)
+ return failure();
+ // Extract from the un-extended input, then re-apply the same extension op.
+ return TypeSwitch<Operation *, LogicalResult>(def)
+ .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
+ Value newExtract = rewriter.create<vector::ExtractElementOp>(
+ op.getLoc(), extOp.getIn(), op.getPosition());
+ rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
+ newExtract);
+ return success();
+ })
+ .Default(failure());
+ }
+};
+/// Rewrites `vector.extract_strided_slice` of an extsi/extui to slice first.
+struct ExtensionOverExtractStridedSlice final
+ : OpRewritePattern<vector::ExtractStridedSliceOp> {
+ using OpRewritePattern::OpRewritePattern;
+ /// Fails unless the source vector is defined by arith.extsi / arith.extui.
+ LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
+ PatternRewriter &rewriter) const override {
+ Operation *def = op.getVector().getDefiningOp();
+ if (!def)
+ return failure();
+ // New slice type: shape of the original result, element type of ext input.
+ return TypeSwitch<Operation *, LogicalResult>(def)
+ .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
+ VectorType origTy = op.getType();
+ Type inElemTy =
+ cast<VectorType>(extOp.getIn().getType()).getElementType();
+ VectorType extractTy = origTy.cloneWith(origTy.getShape(), inElemTy);
+ Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
+ op.getLoc(), extractTy, extOp.getIn(), op.getOffsets(),
+ op.getSizes(), op.getStrides());
+ rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
+ newExtract);
+ return success();
+ })
+ .Default(failure());
+ }
+};
+
//===----------------------------------------------------------------------===//
// Pass Definitions
//===----------------------------------------------------------------------===//
@@ -169,6 +246,12 @@ struct ArithIntNarrowingPass final
void populateArithIntNarrowingPatterns(
RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) {
+ // Add commute patterns with a higher benefit. This is to expose more
+ // optimization opportunities to narrowing patterns.
+ patterns.add<ExtensionOverExtract, ExtensionOverExtractElement,
+ ExtensionOverExtractStridedSlice>(patterns.getContext(),
+ PatternBenefit(2));
+ // These patterns get the pass options; no explicit benefit is passed here.
patterns.add<SIToFPPattern, UIToFPPattern>(patterns.getContext(), options);
}
diff --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir
index 21d5ab774c87b..f1290e552fd77 100644
--- a/mlir/test/Dialect/Arith/int-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-narrowing.mlir
@@ -1,6 +1,10 @@
// RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,32" \
// RUN: --verify-diagnostics %s | FileCheck %s
+//===----------------------------------------------------------------------===//
+// arith.*itofp
+//===----------------------------------------------------------------------===//
+
// CHECK-LABEL: func.func @sitofp_extsi_i16
// CHECK-SAME: (%[[ARG:.+]]: i16)
// CHECK-NEXT: %[[RET:.+]] = arith.sitofp %[[ARG]] : i16 to f16
@@ -131,3 +135,103 @@ func.func @uitofp_extsi_i16(%a: i16) -> f16 {
%f = arith.uitofp %b : i32 to f16
return %f : f16
}
+
+//===----------------------------------------------------------------------===//
+// Commute Extension over Vector Ops
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @extsi_over_extract_3xi16
+// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>)
+// CHECK-NEXT: %[[EXTR:.+]] = vector.extract %[[ARG]][1] : vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.sitofp %[[EXTR]] : i16 to f16
+// CHECK-NEXT: return %[[RET]] : f16
+func.func @extsi_over_extract_3xi16(%a: vector<3xi16>) -> f16 {
+ %b = arith.extsi %a : vector<3xi16> to vector<3xi32>
+ %c = vector.extract %b[1] : vector<3xi32>
+ %f = arith.sitofp %c : i32 to f16
+ return %f : f16
+}
+
+// CHECK-LABEL: func.func @extui_over_extract_3xi16
+// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>)
+// CHECK-NEXT: %[[EXTR:.+]] = vector.extract %[[ARG]][1] : vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.uitofp %[[EXTR]] : i16 to f16
+// CHECK-NEXT: return %[[RET]] : f16
+func.func @extui_over_extract_3xi16(%a: vector<3xi16>) -> f16 {
+ %b = arith.extui %a : vector<3xi16> to vector<3xi32>
+ %c = vector.extract %b[1] : vector<3xi32>
+ %f = arith.uitofp %c : i32 to f16
+ return %f : f16
+}
+
+// CHECK-LABEL: func.func @extsi_over_extractelement_3xi16
+// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>, %[[POS:.+]]: i32)
+// CHECK-NEXT: %[[EXTR:.+]] = vector.extractelement %[[ARG]][%[[POS]] : i32] : vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.sitofp %[[EXTR]] : i16 to f16
+// CHECK-NEXT: return %[[RET]] : f16
+func.func @extsi_over_extractelement_3xi16(%a: vector<3xi16>, %pos: i32) -> f16 {
+ %b = arith.extsi %a : vector<3xi16> to vector<3xi32>
+ %c = vector.extractelement %b[%pos : i32] : vector<3xi32>
+ %f = arith.sitofp %c : i32 to f16
+ return %f : f16
+}
+
+// CHECK-LABEL: func.func @extui_over_extractelement_3xi16
+// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>, %[[POS:.+]]: i32)
+// CHECK-NEXT: %[[EXTR:.+]] = vector.extractelement %[[ARG]][%[[POS]] : i32] : vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.uitofp %[[EXTR]] : i16 to f16
+// CHECK-NEXT: return %[[RET]] : f16
+func.func @extui_over_extractelement_3xi16(%a: vector<3xi16>, %pos: i32) -> f16 {
+ %b = arith.extui %a : vector<3xi16> to vector<3xi32>
+ %c = vector.extractelement %b[%pos : i32] : vector<3xi32>
+ %f = arith.uitofp %c : i32 to f16
+ return %f : f16
+}
+
+// CHECK-LABEL: func.func @extsi_over_extract_strided_slice_1d
+// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>)
+// CHECK-NEXT: %[[EXTR:.+]] = vector.extract_strided_slice %[[ARG]] {offsets = [1], sizes = [2], strides = [1]} : vector<3xi16> to vector<2xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[EXTR]] : vector<2xi16> to vector<2xi32>
+// CHECK-NEXT: return %[[RET]] : vector<2xi32>
+func.func @extsi_over_extract_strided_slice_1d(%a: vector<3xi16>) -> vector<2xi32> {
+ %b = arith.extsi %a : vector<3xi16> to vector<3xi32>
+ %c = vector.extract_strided_slice %b
+ {offsets = [1], sizes = [2], strides = [1]} : vector<3xi32> to vector<2xi32>
+ return %c : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_extract_strided_slice_1d
+// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>)
+// CHECK-NEXT: %[[EXTR:.+]] = vector.extract_strided_slice %[[ARG]] {offsets = [1], sizes = [2], strides = [1]} : vector<3xi16> to vector<2xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[EXTR]] : vector<2xi16> to vector<2xi32>
+// CHECK-NEXT: return %[[RET]] : vector<2xi32>
+func.func @extui_over_extract_strided_slice_1d(%a: vector<3xi16>) -> vector<2xi32> {
+ %b = arith.extui %a : vector<3xi16> to vector<3xi32>
+ %c = vector.extract_strided_slice %b
+ {offsets = [1], sizes = [2], strides = [1]} : vector<3xi32> to vector<2xi32>
+ return %c : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @extsi_over_extract_strided_slice_2d
+// CHECK-SAME: (%[[ARG:.+]]: vector<2x3xi16>)
+// CHECK-NEXT: %[[EXTR:.+]] = vector.extract_strided_slice %[[ARG]] {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi16> to vector<1x2xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[EXTR]] : vector<1x2xi16> to vector<1x2xi32>
+// CHECK-NEXT: return %[[RET]] : vector<1x2xi32>
+func.func @extsi_over_extract_strided_slice_2d(%a: vector<2x3xi16>) -> vector<1x2xi32> {
+ %b = arith.extsi %a : vector<2x3xi16> to vector<2x3xi32>
+ %c = vector.extract_strided_slice %b
+ {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi32> to vector<1x2xi32>
+ return %c : vector<1x2xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_extract_strided_slice_2d
+// CHECK-SAME: (%[[ARG:.+]]: vector<2x3xi16>)
+// CHECK-NEXT: %[[EXTR:.+]] = vector.extract_strided_slice %[[ARG]] {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi16> to vector<1x2xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[EXTR]] : vector<1x2xi16> to vector<1x2xi32>
+// CHECK-NEXT: return %[[RET]] : vector<1x2xi32>
+func.func @extui_over_extract_strided_slice_2d(%a: vector<2x3xi16>) -> vector<1x2xi32> {
+ %b = arith.extui %a : vector<2x3xi16> to vector<2x3xi32>
+ %c = vector.extract_strided_slice %b
+ {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi32> to vector<1x2xi32>
+ return %c : vector<1x2xi32>
+}
More information about the Mlir-commits
mailing list