[Mlir-commits] [mlir] 3ff8708 - [mlir][arith] Add narrowing patterns for other insertion ops

Jakub Kuderski llvmlistbot at llvm.org
Mon May 1 11:30:34 PDT 2023


Author: Jakub Kuderski
Date: 2023-05-01T14:29:02-04:00
New Revision: 3ff870881f5f0d3d08753efd558ac5f05d04a574

URL: https://github.com/llvm/llvm-project/commit/3ff870881f5f0d3d08753efd558ac5f05d04a574
DIFF: https://github.com/llvm/llvm-project/commit/3ff870881f5f0d3d08753efd558ac5f05d04a574.diff

LOG: [mlir][arith] Add narrowing patterns for other insertion ops

Allow extension ops to be commuted over `vector.insertelement` and
`vector.insert_strided_slice`.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D149509

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
    mlir/test/Dialect/Arith/int-narrowing.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index c515824a8e04d..97164621e45c9 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -306,27 +306,35 @@ struct ExtensionOverExtractStridedSlice final
   }
 };
 
-struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
-  using NarrowingPattern::NarrowingPattern;
-
-  LogicalResult matchAndRewrite(vector::InsertOp op,
-                                PatternRewriter &rewriter) const override {
+/// Base pattern for narrowing patterns over vector insertion ops.
+template <typename InsertionOp>
+struct ExtensionOverInsertionPattern : NarrowingPattern<InsertionOp> {
+  using NarrowingPattern<InsertionOp>::NarrowingPattern;
+
+  /// Derived classes must provide a function to create the matching insertion
+  /// op based on the original op and new arguments.
+  virtual InsertionOp createInsertionOp(PatternRewriter &rewriter,
+                                        InsertionOp origInsert,
+                                        Value narrowValue,
+                                        Value narrowDest) const = 0;
+
+  LogicalResult matchAndRewrite(InsertionOp op,
+                                PatternRewriter &rewriter) const final {
     FailureOr<ExtensionOp> ext =
         ExtensionOp::from(op.getSource().getDefiningOp());
     if (failed(ext))
       return failure();
 
-    FailureOr<vector::InsertOp> newInsert =
-        createNarrowInsert(op, rewriter, *ext);
+    FailureOr<InsertionOp> newInsert = createNarrowInsert(op, rewriter, *ext);
     if (failed(newInsert))
       return failure();
     ext->recreateAndReplace(rewriter, op, *newInsert);
     return success();
   }
 
-  FailureOr<vector::InsertOp> createNarrowInsert(vector::InsertOp op,
-                                                 PatternRewriter &rewriter,
-                                                 ExtensionOp insValue) const {
+  FailureOr<InsertionOp> createNarrowInsert(InsertionOp op,
+                                            PatternRewriter &rewriter,
+                                            ExtensionOp insValue) const {
     // Calculate the operand and result bitwidths. We can only apply narrowing
     // when the inserted source value and destination vector require fewer bits
     // than the result. Because the source and destination may have different
@@ -337,6 +345,8 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
     if (failed(origBitsRequired))
       return failure();
 
+    // TODO: We could relax this check by disregarding bitwidth requirements of
+    // elements that we know will be replaced by the insertion.
     FailureOr<unsigned> destBitsRequired =
         calculateBitsRequired(op.getDest(), insValue.getKind());
     if (failed(destBitsRequired) || *destBitsRequired >= *origBitsRequired)
@@ -352,12 +362,13 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
     // both the source and the destination values.
     unsigned newInsertionBits =
         std::max(*destBitsRequired, *insertedBitsRequired);
-    FailureOr<Type> newVecTy = getNarrowType(newInsertionBits, op.getType());
+    FailureOr<Type> newVecTy =
+        this->getNarrowType(newInsertionBits, op.getType());
     if (failed(newVecTy) || *newVecTy == op.getType())
       return failure();
 
     FailureOr<Type> newInsertedValueTy =
-        getNarrowType(newInsertionBits, insValue.getType());
+        this->getNarrowType(newInsertionBits, insValue.getType());
     if (failed(newInsertedValueTy))
       return failure();
 
@@ -366,8 +377,47 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
         loc, *newInsertedValueTy, insValue.getResult());
     Value narrowDest =
         rewriter.createOrFold<arith::TruncIOp>(loc, *newVecTy, op.getDest());
-    return rewriter.create<vector::InsertOp>(loc, narrowValue, narrowDest,
-                                             op.getPosition());
+    return createInsertionOp(rewriter, op, narrowValue, narrowDest);
+  }
+};
+
+struct ExtensionOverInsert final
+    : ExtensionOverInsertionPattern<vector::InsertOp> {
+  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
+
+  vector::InsertOp createInsertionOp(PatternRewriter &rewriter,
+                                     vector::InsertOp origInsert,
+                                     Value narrowValue,
+                                     Value narrowDest) const override {
+    return rewriter.create<vector::InsertOp>(
+        origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition());
+  }
+};
+
+struct ExtensionOverInsertElement final
+    : ExtensionOverInsertionPattern<vector::InsertElementOp> {
+  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
+
+  vector::InsertElementOp createInsertionOp(PatternRewriter &rewriter,
+                                            vector::InsertElementOp origInsert,
+                                            Value narrowValue,
+                                            Value narrowDest) const override {
+    return rewriter.create<vector::InsertElementOp>(
+        origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition());
+  }
+};
+
+struct ExtensionOverInsertStridedSlice final
+    : ExtensionOverInsertionPattern<vector::InsertStridedSliceOp> {
+  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
+
+  vector::InsertStridedSliceOp
+  createInsertionOp(PatternRewriter &rewriter,
+                    vector::InsertStridedSliceOp origInsert, Value narrowValue,
+                    Value narrowDest) const override {
+    return rewriter.create<vector::InsertStridedSliceOp>(
+        origInsert.getLoc(), narrowValue, narrowDest, origInsert.getOffsets(),
+        origInsert.getStrides());
   }
 };
 
@@ -400,7 +450,8 @@ void populateArithIntNarrowingPatterns(
   // Add commute patterns with a higher benefit. This is to expose more
   // optimization opportunities to narrowing patterns.
   patterns.add<ExtensionOverExtract, ExtensionOverExtractElement,
-               ExtensionOverExtractStridedSlice, ExtensionOverInsert>(
+               ExtensionOverExtractStridedSlice, ExtensionOverInsert,
+               ExtensionOverInsertElement, ExtensionOverInsertStridedSlice>(
       patterns.getContext(), options, PatternBenefit(2));
 
   patterns.add<SIToFPPattern, UIToFPPattern>(patterns.getContext(), options);

diff --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir
index d98e03d93b030..6d5299c2f00da 100644
--- a/mlir/test/Dialect/Arith/int-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-narrowing.mlir
@@ -328,3 +328,117 @@ func.func @extui_over_insert_3xi16_cst_i16(%a: i8) -> vector<3xi32> {
   %e = vector.insert %d, %cst [1] : i32 into vector<3xi32>
   return %e : vector<3xi32>
 }
+
+// CHECK-LABEL: func.func @extsi_over_insertelement_3xi16
+// CHECK-SAME:    (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16, %[[POS:.+]]: i32)
+// CHECK-NEXT:    %[[INS:.+]] = vector.insertelement %[[ARG1]], %[[ARG0]][%[[POS]] : i32] : vector<3xi16>
+// CHECK-NEXT:    %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
+func.func @extsi_over_insertelement_3xi16(%a: vector<3xi16>, %b: i16, %pos: i32) -> vector<3xi32> {
+  %c = arith.extsi %a : vector<3xi16> to vector<3xi32>
+  %d = arith.extsi %b : i16 to i32
+  %e = vector.insertelement %d, %c[%pos : i32] : vector<3xi32>
+  return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_insertelement_3xi16
+// CHECK-SAME:    (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16, %[[POS:.+]]: i32)
+// CHECK-NEXT:    %[[INS:.+]] = vector.insertelement %[[ARG1]], %[[ARG0]][%[[POS]] : i32] : vector<3xi16>
+// CHECK-NEXT:    %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
+func.func @extui_over_insertelement_3xi16(%a: vector<3xi16>, %b: i16, %pos: i32) -> vector<3xi32> {
+  %c = arith.extui %a : vector<3xi16> to vector<3xi32>
+  %d = arith.extui %b : i16 to i32
+  %e = vector.insertelement %d, %c[%pos : i32] : vector<3xi32>
+  return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extsi_over_insertelement_3xi16_cst_i16
+// CHECK-SAME:    (%[[ARG:.+]]: i8, %[[POS:.+]]: i32)
+// CHECK-NEXT:    %[[CST:.+]]  = arith.constant dense<[-1, 128, 0]> : vector<3xi16>
+// CHECK-NEXT:    %[[SRCE:.+]] = arith.extsi %[[ARG]] : i8 to i32
+// CHECK-NEXT:    %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16
+// CHECK-NEXT:    %[[INS:.+]] = vector.insertelement %[[SRCT]], %[[CST]][%[[POS]] : i32] : vector<3xi16>
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
+func.func @extsi_over_insertelement_3xi16_cst_i16(%a: i8, %pos: i32) -> vector<3xi32> {
+  %cst = arith.constant dense<[-1, 128, 0]> : vector<3xi32>
+  %d = arith.extsi %a : i8 to i32
+  %e = vector.insertelement %d, %cst[%pos : i32] : vector<3xi32>
+  return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_insertelement_3xi16_cst_i16
+// CHECK-SAME:    (%[[ARG:.+]]: i8, %[[POS:.+]]: i32)
+// CHECK-NEXT:    %[[CST:.+]]  = arith.constant dense<[1, 256, 0]> : vector<3xi16>
+// CHECK-NEXT:    %[[SRCE:.+]] = arith.extui %[[ARG]] : i8 to i32
+// CHECK-NEXT:    %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16
+// CHECK-NEXT:    %[[INS:.+]] = vector.insertelement %[[SRCT]], %[[CST]][%[[POS]] : i32] : vector<3xi16>
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
+func.func @extui_over_insertelement_3xi16_cst_i16(%a: i8, %pos: i32) -> vector<3xi32> {
+  %cst = arith.constant dense<[1, 256, 0]> : vector<3xi32>
+  %d = arith.extui %a : i8 to i32
+  %e = vector.insertelement %d, %cst[%pos : i32] : vector<3xi32>
+  return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extsi_over_insert_strided_slice_1d
+// CHECK-SAME:    (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: vector<2xi16>)
+// CHECK-NEXT:    %[[INS:.+]] = vector.insert_strided_slice %[[ARG1]], %[[ARG0]]
+// CHECK-SAME:                    {offsets = [1], strides = [1]} : vector<2xi16> into vector<3xi16>
+// CHECK-NEXT:    %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
+func.func @extsi_over_insert_strided_slice_1d(%a: vector<3xi16>, %b: vector<2xi16>) -> vector<3xi32> {
+  %c = arith.extsi %a : vector<3xi16> to vector<3xi32>
+  %d = arith.extsi %b : vector<2xi16> to vector<2xi32>
+  %e = vector.insert_strided_slice %d, %c {offsets = [1], strides = [1]} : vector<2xi32> into vector<3xi32>
+  return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_insert_strided_slice_1d
+// CHECK-SAME:    (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: vector<2xi16>)
+// CHECK-NEXT:    %[[INS:.+]] = vector.insert_strided_slice %[[ARG1]], %[[ARG0]]
+// CHECK-SAME:                    {offsets = [1], strides = [1]} : vector<2xi16> into vector<3xi16>
+// CHECK-NEXT:    %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
+func.func @extui_over_insert_strided_slice_1d(%a: vector<3xi16>, %b: vector<2xi16>) -> vector<3xi32> {
+  %c = arith.extui %a : vector<3xi16> to vector<3xi32>
+  %d = arith.extui %b : vector<2xi16> to vector<2xi32>
+  %e = vector.insert_strided_slice %d, %c {offsets = [1], strides = [1]} : vector<2xi32> into vector<3xi32>
+  return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extsi_over_insert_strided_slice_cst_2d
+// CHECK-SAME:    (%[[ARG:.+]]: vector<1x2xi8>)
+// CHECK-NEXT:    %[[CST:.+]]  = arith.constant
+// CHECK-SAME{LITERAL}:            dense<[[-1, 128, 0], [-129, 42, 1337]]> : vector<2x3xi16>
+// CHECK-NEXT:    %[[SRCE:.+]] = arith.extsi %[[ARG]] : vector<1x2xi8> to vector<1x2xi32>
+// CHECK-NEXT:    %[[SRCT:.+]] = arith.trunci %[[SRCE]] : vector<1x2xi32> to vector<1x2xi16>
+// CHECK-NEXT:    %[[INS:.+]] = vector.insert_strided_slice %[[SRCT]], %[[CST]]
+// CHECK-SAME:                    {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi16> into vector<2x3xi16>
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[INS]] : vector<2x3xi16> to vector<2x3xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<2x3xi32>
+func.func @extsi_over_insert_strided_slice_cst_2d(%a: vector<1x2xi8>) -> vector<2x3xi32> {
+  %cst = arith.constant dense<[[-1, 128, 0], [-129, 42, 1337]]> : vector<2x3xi32>
+  %d = arith.extsi %a : vector<1x2xi8> to vector<1x2xi32>
+  %e = vector.insert_strided_slice %d, %cst {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi32> into vector<2x3xi32>
+  return %e : vector<2x3xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_insert_strided_slice_cst_2d
+// CHECK-SAME:    (%[[ARG:.+]]: vector<1x2xi8>)
+// CHECK-NEXT:    %[[CST:.+]]  = arith.constant
+// CHECK-SAME{LITERAL}:            dense<[[1, 128, 0], [256, 42, 1337]]> : vector<2x3xi16>
+// CHECK-NEXT:    %[[SRCE:.+]] = arith.extui %[[ARG]] : vector<1x2xi8> to vector<1x2xi32>
+// CHECK-NEXT:    %[[SRCT:.+]] = arith.trunci %[[SRCE]] : vector<1x2xi32> to vector<1x2xi16>
+// CHECK-NEXT:    %[[INS:.+]] = vector.insert_strided_slice %[[SRCT]], %[[CST]]
+// CHECK-SAME:                    {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi16> into vector<2x3xi16>
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[INS]] : vector<2x3xi16> to vector<2x3xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<2x3xi32>
+func.func @extui_over_insert_strided_slice_cst_2d(%a: vector<1x2xi8>) -> vector<2x3xi32> {
+  %cst = arith.constant dense<[[1, 128, 0], [256, 42, 1337]]> : vector<2x3xi32>
+  %d = arith.extui %a : vector<1x2xi8> to vector<1x2xi32>
+  %e = vector.insert_strided_slice %d, %cst {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi32> into vector<2x3xi32>
+  return %e : vector<2x3xi32>
+}


        


More information about the Mlir-commits mailing list