[Mlir-commits] [mlir] 7f3b0e5 - [mlir][arith] Add narrowing patterns to commute more vector ops

Jakub Kuderski llvmlistbot at llvm.org
Mon May 1 11:34:26 PDT 2023


Author: Jakub Kuderski
Date: 2023-05-01T14:32:57-04:00
New Revision: 7f3b0e584513611bb1d804892eb269ae45d8e715

URL: https://github.com/llvm/llvm-project/commit/7f3b0e584513611bb1d804892eb269ae45d8e715
DIFF: https://github.com/llvm/llvm-project/commit/7f3b0e584513611bb1d804892eb269ae45d8e715.diff

LOG: [mlir][arith] Add narrowing patterns to commute more vector ops

This commutes the extension (`arith.extsi`, `arith.extui`) over the
following vector ops: `vector.broadcast`, `vector.shape_cast`,
`vector.transpose`, `vector.flat_transpose`.

I focused on these because I saw them being created by vector unroll
patterns (except perhaps `vector.flat_transpose`).

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D149534

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
    mlir/test/Dialect/Arith/int-narrowing.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 97164621e45c9..0c7afd9255bcd 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -249,6 +249,26 @@ using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;
 // Patterns to Commute Extension Ops
 //===----------------------------------------------------------------------===//
 
+/// Rewrites `broadcast(ext(x))` into `ext(broadcast(x))` so the broadcast
+/// operates on the narrower element type.
+struct ExtensionOverBroadcast final : NarrowingPattern<vector::BroadcastOp> {
+  using NarrowingPattern::NarrowingPattern;
+  LogicalResult matchAndRewrite(vector::BroadcastOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<ExtensionOp> srcExt =
+        ExtensionOp::from(op.getSource().getDefiningOp());
+    if (failed(srcExt))
+      return failure();
+    VectorType wideTy = op.getResultVectorType();
+    VectorType narrowTy =
+        wideTy.cloneWith(wideTy.getShape(), srcExt->getInElementType());
+    Value narrowBroadcast = rewriter.create<vector::BroadcastOp>(
+        op.getLoc(), narrowTy, srcExt->getIn());
+    srcExt->recreateAndReplace(rewriter, op, narrowBroadcast);
+    return success();
+  }
+};
+
 struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
   using NarrowingPattern::NarrowingPattern;
 
@@ -421,6 +441,68 @@ struct ExtensionOverInsertStridedSlice final
   }
 };
 
+/// Rewrites `shape_cast(ext(x))` into `ext(shape_cast(x))` so the reshape
+/// operates on the narrower element type.
+struct ExtensionOverShapeCast final : NarrowingPattern<vector::ShapeCastOp> {
+  using NarrowingPattern::NarrowingPattern;
+  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<ExtensionOp> srcExt =
+        ExtensionOp::from(op.getSource().getDefiningOp());
+    if (failed(srcExt))
+      return failure();
+    VectorType wideTy = op.getResultVectorType();
+    VectorType narrowTy =
+        wideTy.cloneWith(wideTy.getShape(), srcExt->getInElementType());
+    Value narrowCast = rewriter.create<vector::ShapeCastOp>(
+        op.getLoc(), narrowTy, srcExt->getIn());
+    srcExt->recreateAndReplace(rewriter, op, narrowCast);
+    return success();
+  }
+};
+
+/// Rewrites `transpose(ext(x))` into `ext(transpose(x))`, reusing the same
+/// permutation, so the transpose operates on the narrower element type.
+struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
+  using NarrowingPattern::NarrowingPattern;
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<ExtensionOp> srcExt =
+        ExtensionOp::from(op.getVector().getDefiningOp());
+    if (failed(srcExt))
+      return failure();
+    VectorType wideTy = op.getResultVectorType();
+    VectorType narrowTy =
+        wideTy.cloneWith(wideTy.getShape(), srcExt->getInElementType());
+    Value narrowTranspose = rewriter.create<vector::TransposeOp>(
+        op.getLoc(), narrowTy, srcExt->getIn(), op.getTransp());
+    srcExt->recreateAndReplace(rewriter, op, narrowTranspose);
+    return success();
+  }
+};
+
+/// Rewrites `flat_transpose(ext(x))` into `ext(flat_transpose(x))`, keeping
+/// the same `rows`/`columns` attributes, so the transpose stays narrow.
+struct ExtensionOverFlatTranspose final
+    : NarrowingPattern<vector::FlatTransposeOp> {
+  using NarrowingPattern::NarrowingPattern;
+  LogicalResult matchAndRewrite(vector::FlatTransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<ExtensionOp> srcExt =
+        ExtensionOp::from(op.getMatrix().getDefiningOp());
+    if (failed(srcExt))
+      return failure();
+    VectorType wideTy = op.getType();
+    VectorType narrowTy =
+        wideTy.cloneWith(wideTy.getShape(), srcExt->getInElementType());
+    Value narrowTranspose = rewriter.create<vector::FlatTransposeOp>(
+        op.getLoc(), narrowTy, srcExt->getIn(), op.getRowsAttr(),
+        op.getColumnsAttr());
+    srcExt->recreateAndReplace(rewriter, op, narrowTranspose);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Pass Definitions
 //===----------------------------------------------------------------------===//
@@ -449,9 +531,11 @@ void populateArithIntNarrowingPatterns(
     RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) {
   // Add commute patterns with a higher benefit. This is to expose more
   // optimization opportunities to narrowing patterns.
-  patterns.add<ExtensionOverExtract, ExtensionOverExtractElement,
-               ExtensionOverExtractStridedSlice, ExtensionOverInsert,
-               ExtensionOverInsertElement, ExtensionOverInsertStridedSlice>(
+  patterns.add<ExtensionOverBroadcast, ExtensionOverExtract,
+               ExtensionOverExtractElement, ExtensionOverExtractStridedSlice,
+               ExtensionOverInsert, ExtensionOverInsertElement,
+               ExtensionOverInsertStridedSlice, ExtensionOverShapeCast,
+               ExtensionOverTranspose, ExtensionOverFlatTranspose>(
       patterns.getContext(), options, PatternBenefit(2));
 
   patterns.add<SIToFPPattern, UIToFPPattern>(patterns.getContext(), options);

diff  --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir
index 6d5299c2f00da..675a52b5d53e6 100644
--- a/mlir/test/Dialect/Arith/int-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-narrowing.mlir
@@ -442,3 +442,91 @@ func.func @extui_over_insert_strided_slice_cst_2d(%a: vector<1x2xi8>) -> vector<
   %e = vector.insert_strided_slice %d, %cst {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi32> into vector<2x3xi32>
   return %e : vector<2x3xi32>
 }
+
+// CHECK-LABEL: func.func @extsi_over_broadcast_3xi16
+// CHECK-SAME:    (%[[SRC:.+]]: i16)
+// CHECK-NEXT:    %[[NARROW:.+]] = vector.broadcast %[[SRC]] : i16 to vector<3xi16>
+// CHECK-NEXT:    %[[WIDE:.+]] = arith.extsi %[[NARROW]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT:    return %[[WIDE]] : vector<3xi32>
+func.func @extsi_over_broadcast_3xi16(%src: i16) -> vector<3xi32> {
+  %wide = arith.extsi %src : i16 to i32
+  %bcast = vector.broadcast %wide : i32 to vector<3xi32>
+  return %bcast : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_broadcast_2x3xi16
+// CHECK-SAME:    (%[[SRC:.+]]: vector<3xi16>)
+// CHECK-NEXT:    %[[NARROW:.+]] = vector.broadcast %[[SRC]] : vector<3xi16> to vector<2x3xi16>
+// CHECK-NEXT:    %[[WIDE:.+]] = arith.extui %[[NARROW]] : vector<2x3xi16> to vector<2x3xi32>
+// CHECK-NEXT:    return %[[WIDE]] : vector<2x3xi32>
+func.func @extui_over_broadcast_2x3xi16(%src: vector<3xi16>) -> vector<2x3xi32> {
+  %wide = arith.extui %src : vector<3xi16> to vector<3xi32>
+  %bcast = vector.broadcast %wide : vector<3xi32> to vector<2x3xi32>
+  return %bcast : vector<2x3xi32>
+}
+
+// CHECK-LABEL: func.func @extsi_over_shape_cast_2x3xi16
+// CHECK-SAME:    (%[[SRC:.+]]: vector<2x3xi16>)
+// CHECK-NEXT:    %[[NARROW:.+]] = vector.shape_cast %[[SRC]] : vector<2x3xi16> to vector<3x2xi16>
+// CHECK-NEXT:    %[[WIDE:.+]] = arith.extsi %[[NARROW]] : vector<3x2xi16> to vector<3x2xi32>
+// CHECK-NEXT:    return %[[WIDE]] : vector<3x2xi32>
+func.func @extsi_over_shape_cast_2x3xi16(%src: vector<2x3xi16>) -> vector<3x2xi32> {
+  %wide = arith.extsi %src : vector<2x3xi16> to vector<2x3xi32>
+  %cast = vector.shape_cast %wide : vector<2x3xi32> to vector<3x2xi32>
+  return %cast : vector<3x2xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_shape_cast_5x2x3xi16
+// CHECK-SAME:    (%[[SRC:.+]]: vector<5x2x3xi16>)
+// CHECK-NEXT:    %[[NARROW:.+]] = vector.shape_cast %[[SRC]] : vector<5x2x3xi16> to vector<2x3x5xi16>
+// CHECK-NEXT:    %[[WIDE:.+]] = arith.extui %[[NARROW]] : vector<2x3x5xi16> to vector<2x3x5xi32>
+// CHECK-NEXT:    return %[[WIDE]] : vector<2x3x5xi32>
+func.func @extui_over_shape_cast_5x2x3xi16(%src: vector<5x2x3xi16>) -> vector<2x3x5xi32> {
+  %wide = arith.extui %src : vector<5x2x3xi16> to vector<5x2x3xi32>
+  %cast = vector.shape_cast %wide : vector<5x2x3xi32> to vector<2x3x5xi32>
+  return %cast : vector<2x3x5xi32>
+}
+
+// CHECK-LABEL: func.func @extsi_over_transpose_2x3xi16
+// CHECK-SAME:    (%[[SRC:.+]]: vector<2x3xi16>)
+// CHECK-NEXT:    %[[NARROW:.+]] = vector.transpose %[[SRC]], [1, 0] : vector<2x3xi16> to vector<3x2xi16>
+// CHECK-NEXT:    %[[WIDE:.+]] = arith.extsi %[[NARROW]] : vector<3x2xi16> to vector<3x2xi32>
+// CHECK-NEXT:    return %[[WIDE]] : vector<3x2xi32>
+func.func @extsi_over_transpose_2x3xi16(%src: vector<2x3xi16>) -> vector<3x2xi32> {
+  %wide = arith.extsi %src : vector<2x3xi16> to vector<2x3xi32>
+  %perm = vector.transpose %wide, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
+  return %perm : vector<3x2xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_transpose_5x2x3xi16
+// CHECK-SAME:    (%[[SRC:.+]]: vector<5x2x3xi16>)
+// CHECK-NEXT:    %[[NARROW:.+]] = vector.transpose %[[SRC]], [1, 2, 0] : vector<5x2x3xi16> to vector<2x3x5xi16>
+// CHECK-NEXT:    %[[WIDE:.+]] = arith.extui %[[NARROW]] : vector<2x3x5xi16> to vector<2x3x5xi32>
+// CHECK-NEXT:    return %[[WIDE]] : vector<2x3x5xi32>
+func.func @extui_over_transpose_5x2x3xi16(%src: vector<5x2x3xi16>) -> vector<2x3x5xi32> {
+  %wide = arith.extui %src : vector<5x2x3xi16> to vector<5x2x3xi32>
+  %perm = vector.transpose %wide, [1, 2, 0] : vector<5x2x3xi32> to vector<2x3x5xi32>
+  return %perm : vector<2x3x5xi32>
+}
+
+// CHECK-LABEL: func.func @extsi_over_flat_transpose_16xi16
+// CHECK-SAME:    (%[[SRC:.+]]: vector<16xi16>)
+// CHECK-NEXT:    %[[NARROW:.+]] = vector.flat_transpose %[[SRC]] {columns = 4 : i32, rows = 4 : i32} : vector<16xi16> -> vector<16xi16>
+// CHECK-NEXT:    %[[WIDE:.+]] = arith.extsi %[[NARROW]] : vector<16xi16> to vector<16xi32>
+// CHECK-NEXT:    return %[[WIDE]] : vector<16xi32>
+func.func @extsi_over_flat_transpose_16xi16(%src: vector<16xi16>) -> vector<16xi32> {
+  %wide = arith.extsi %src : vector<16xi16> to vector<16xi32>
+  %perm = vector.flat_transpose %wide {columns = 4 : i32, rows = 4 : i32} : vector<16xi32> -> vector<16xi32>
+  return %perm : vector<16xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_flat_transpose_16xi16
+// CHECK-SAME:    (%[[SRC:.+]]: vector<16xi16>)
+// CHECK-NEXT:    %[[NARROW:.+]] = vector.flat_transpose %[[SRC]] {columns = 8 : i32, rows = 2 : i32} : vector<16xi16> -> vector<16xi16>
+// CHECK-NEXT:    %[[WIDE:.+]] = arith.extui %[[NARROW]] : vector<16xi16> to vector<16xi32>
+// CHECK-NEXT:    return %[[WIDE]] : vector<16xi32>
+func.func @extui_over_flat_transpose_16xi16(%src: vector<16xi16>) -> vector<16xi32> {
+  %wide = arith.extui %src : vector<16xi16> to vector<16xi32>
+  %perm = vector.flat_transpose %wide {columns = 8 : i32, rows = 2 : i32} : vector<16xi32> -> vector<16xi32>
+  return %perm : vector<16xi32>
+}


        


More information about the Mlir-commits mailing list