[Mlir-commits] [mlir] 33017e5 - [mlir][arith] Add narrowing pattern to commute extension over insertion

Jakub Kuderski llvmlistbot at llvm.org
Fri Apr 28 13:19:12 PDT 2023


Author: Jakub Kuderski
Date: 2023-04-28T16:17:44-04:00
New Revision: 33017e5a3fa2c3194522565cd0e106a931b072b3

URL: https://github.com/llvm/llvm-project/commit/33017e5a3fa2c3194522565cd0e106a931b072b3
DIFF: https://github.com/llvm/llvm-project/commit/33017e5a3fa2c3194522565cd0e106a931b072b3.diff

LOG: [mlir][arith] Add narrowing pattern to commute extension over insertion

This enables more optimization opportunities by moving
zero/sign-extensions closer to their uses.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D149282

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
    mlir/test/Dialect/Arith/int-narrowing.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 3401a9c05b632..639b19b0a5d8a 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -70,7 +70,7 @@ struct NarrowingPattern : OpRewritePattern<SourceOp> {
     if (!isa<IntegerType>(elemTy))
       return failure();
 
-    auto newElemTy = IntegerType::get(origTy.getContext(), bitsRequired);
+    auto newElemTy = IntegerType::get(origTy.getContext(), *bestBitwidth);
     if (newElemTy == elemTy)
       return failure();
 
@@ -100,11 +100,58 @@ FailureOr<unsigned> calculateBitsRequired(Type type) {
 
 /// The kind of integer extension being reasoned about: sign-extension
 /// (arith.extsi) or zero-extension (arith.extui).
 enum class ExtensionKind { Sign, Zero };
 
+/// Maps an extension op to its ExtensionKind: arith.extsi yields `Sign`, and
+/// anything else (expected to be arith.extui) yields `Zero`.
+/// `op` must be a non-null arith.extsi or arith.extui op.
+ExtensionKind getExtensionKind(Operation *op) {
+  assert(op);
+  assert((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) && "Not an extension op");
+  if (isa<arith::ExtSIOp>(op))
+    return ExtensionKind::Sign;
+  return ExtensionKind::Zero;
+}
+
+/// Returns the smallest integer bitwidth that can hold `value` when it is
+/// interpreted as unsigned (`ExtensionKind::Zero`) or as signed
+/// (`ExtensionKind::Sign`).
+unsigned calculateBitsRequired(const APInt &value,
+                               ExtensionKind lookThroughExtension) {
+  if (lookThroughExtension != ExtensionKind::Zero) {
+    // Signed interpretation: keep every bit that differs from the sign bit,
+    // plus one bit for the sign itself. For nonnegative values this equals the
+    // number of active bits plus one, for -1 it is a single bit, and for the
+    // signed minimum it is the full bitwidth.
+    return value.getBitWidth() - value.getNumSignBits() + 1;
+  }
+
+  // Unsigned interpretation: only the active bits matter, but even zero needs
+  // one bit.
+  return std::max(value.getActiveBits(), 1u);
+}
+
 /// Returns the integer bitwidth required to represent `value`.
 /// Looks through either sign- or zero-extension as specified by
 /// `lookThroughExtension`.
 FailureOr<unsigned> calculateBitsRequired(Value value,
                                           ExtensionKind lookThroughExtension) {
+  // Handle constants.
+  if (TypedAttr attr; matchPattern(value, m_Constant(&attr))) {
+    if (auto intAttr = dyn_cast<IntegerAttr>(attr))
+      return calculateBitsRequired(intAttr.getValue(), lookThroughExtension);
+
+    if (auto elemsAttr = dyn_cast<DenseElementsAttr>(attr)) {
+      if (elemsAttr.getElementType().isIntOrIndex()) {
+        if (elemsAttr.isSplat())
+          return calculateBitsRequired(elemsAttr.getSplatValue<APInt>(),
+                                       lookThroughExtension);
+
+        unsigned maxBits = 1;
+        for (const APInt &elemValue : elemsAttr.getValues<APInt>())
+          maxBits = std::max(
+              maxBits, calculateBitsRequired(elemValue, lookThroughExtension));
+        return maxBits;
+      }
+    }
+  }
+
   if (lookThroughExtension == ExtensionKind::Sign) {
     if (auto sext = value.getDefiningOp<arith::ExtSIOp>())
       return calculateBitsRequired(sext.getIn().getType());
@@ -150,8 +197,8 @@ using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;
 // Patterns to Commute Extension Ops
 //===----------------------------------------------------------------------===//
 
-struct ExtensionOverExtract final : OpRewritePattern<vector::ExtractOp> {
-  using OpRewritePattern::OpRewritePattern;
+struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
+  using NarrowingPattern::NarrowingPattern;
 
   LogicalResult matchAndRewrite(vector::ExtractOp op,
                                 PatternRewriter &rewriter) const override {
@@ -172,8 +219,8 @@ struct ExtensionOverExtract final : OpRewritePattern<vector::ExtractOp> {
 };
 
 struct ExtensionOverExtractElement final
-    : OpRewritePattern<vector::ExtractElementOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : NarrowingPattern<vector::ExtractElementOp> {
+  using NarrowingPattern::NarrowingPattern;
 
   LogicalResult matchAndRewrite(vector::ExtractElementOp op,
                                 PatternRewriter &rewriter) const override {
@@ -194,8 +241,8 @@ struct ExtensionOverExtractElement final
 };
 
 struct ExtensionOverExtractStridedSlice final
-    : OpRewritePattern<vector::ExtractStridedSliceOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : NarrowingPattern<vector::ExtractStridedSliceOp> {
+  using NarrowingPattern::NarrowingPattern;
 
   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
                                 PatternRewriter &rewriter) const override {
@@ -220,6 +267,80 @@ struct ExtensionOverExtractStridedSlice final
   }
 };
 
+/// Commutes a sign- or zero-extension over `vector.insert`:
+///
+///   %r = vector.insert (ext %v), %dest [pos]
+///
+/// is rewritten, when both the inserted value and `%dest` fit in a narrower
+/// bitwidth, into
+///
+///   %n = vector.insert (trunci (ext %v)), (trunci %dest) [pos]
+///   %r = ext %n
+///
+/// which moves the extension closer to the users of the insertion.
+struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
+  using NarrowingPattern::NarrowingPattern;
+
+  LogicalResult matchAndRewrite(vector::InsertOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only fire when the inserted value is produced by an extension op.
+    Operation *sourceDef = op.getSource().getDefiningOp();
+    if (!sourceDef || !isa<arith::ExtSIOp, arith::ExtUIOp>(sourceDef))
+      return failure();
+
+    FailureOr<vector::InsertOp> narrowInsert =
+        createNarrowInsert(op, rewriter, sourceDef);
+    if (failed(narrowInsert))
+      return failure();
+
+    // Widen the narrow insertion back to the original type with the same kind
+    // of extension that produced the original inserted value.
+    if (isa<arith::ExtSIOp>(sourceDef))
+      rewriter.replaceOpWithNewOp<arith::ExtSIOp>(op, op.getType(),
+                                                  *narrowInsert);
+    else
+      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
+                                                  *narrowInsert);
+    return success();
+  }
+
+  /// Creates a `vector.insert` equivalent to `op` but operating on truncated
+  /// operands. `extOp` must be the arith.extsi / arith.extui op that defines
+  /// the value inserted by `op`. Fails when either operand cannot be
+  /// represented in fewer bits than the result element type, or when no
+  /// narrower type is obtained.
+  FailureOr<vector::InsertOp> createNarrowInsert(vector::InsertOp op,
+                                                 PatternRewriter &rewriter,
+                                                 Operation *extOp) const {
+    assert((isa<arith::ExtSIOp, arith::ExtUIOp>(extOp)));
+
+    // Bitwidth of the original (wide) result.
+    FailureOr<unsigned> wideBits = calculateBitsRequired(op.getType());
+    if (failed(wideBits))
+      return failure();
+
+    // The inserted value and the destination vector may have different
+    // bitwidth requirements. Narrowing is only possible when both are strictly
+    // below the result bitwidth; the common narrow width is the larger of the
+    // two. Both are measured under the extension kind of `extOp`, since that
+    // is the extension used to widen the narrow result again.
+    ExtensionKind kind = getExtensionKind(extOp);
+    FailureOr<unsigned> destBits = calculateBitsRequired(op.getDest(), kind);
+    if (failed(destBits) || *destBits >= *wideBits)
+      return failure();
+
+    FailureOr<unsigned> srcBits =
+        calculateBitsRequired(extOp->getOperand(0), kind);
+    if (failed(srcBits) || *srcBits >= *wideBits)
+      return failure();
+
+    unsigned narrowBits = std::max(*destBits, *srcBits);
+    Type wideVecTy = op.getType();
+    FailureOr<Type> narrowVecTy = getNarrowType(narrowBits, wideVecTy);
+    if (failed(narrowVecTy) || *narrowVecTy == wideVecTy)
+      return failure();
+
+    Value extended = extOp->getResult(0);
+    FailureOr<Type> narrowSrcTy = getNarrowType(narrowBits, extended.getType());
+    if (failed(narrowSrcTy))
+      return failure();
+
+    Location loc = op.getLoc();
+    Value narrowSrc =
+        rewriter.createOrFold<arith::TruncIOp>(loc, *narrowSrcTy, extended);
+    Value narrowDest =
+        rewriter.createOrFold<arith::TruncIOp>(loc, *narrowVecTy, op.getDest());
+    return rewriter.create<vector::InsertOp>(loc, narrowSrc, narrowDest,
+                                             op.getPosition());
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Pass Definitions
 //===----------------------------------------------------------------------===//
@@ -249,8 +370,8 @@ void populateArithIntNarrowingPatterns(
   // Add commute patterns with a higher benefit. This is to expose more
   // optimization opportunities to narrowing patterns.
   patterns.add<ExtensionOverExtract, ExtensionOverExtractElement,
-               ExtensionOverExtractStridedSlice>(patterns.getContext(),
-                                                 PatternBenefit(2));
+               ExtensionOverExtractStridedSlice, ExtensionOverInsert>(
+      patterns.getContext(), options, PatternBenefit(2));
 
   patterns.add<SIToFPPattern, UIToFPPattern>(patterns.getContext(), options);
 }

diff  --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir
index f1290e552fd77..d98e03d93b030 100644
--- a/mlir/test/Dialect/Arith/int-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-narrowing.mlir
@@ -235,3 +235,96 @@ func.func @extui_over_extract_strided_slice_2d(%a: vector<2x3xi16>) -> vector<1x
    {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi32> to vector<1x2xi32>
   return %c : vector<1x2xi32>
 }
+
+// Both operands are sign-extended from i16, so the insertion is performed on
+// i16 and a single extsi widens the result.
+// CHECK-LABEL: func.func @extsi_over_insert_3xi16
+// CHECK-SAME:    (%[[VEC:.+]]: vector<3xi16>, %[[ELEM:.+]]: i16)
+// CHECK-NEXT:    %[[NARROW:.+]] = vector.insert %[[ELEM]], %[[VEC]] [1] : i16 into vector<3xi16>
+// CHECK-NEXT:    %[[WIDE:.+]] = arith.extsi %[[NARROW]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT:    return %[[WIDE]] : vector<3xi32>
+func.func @extsi_over_insert_3xi16(%vec: vector<3xi16>, %elem: i16) -> vector<3xi32> {
+  %vec_ext = arith.extsi %vec : vector<3xi16> to vector<3xi32>
+  %elem_ext = arith.extsi %elem : i16 to i32
+  %ins = vector.insert %elem_ext, %vec_ext [1] : i32 into vector<3xi32>
+  return %ins : vector<3xi32>
+}
+
+// Both operands are zero-extended from i16, so the insertion is performed on
+// i16 and a single extui widens the result.
+// CHECK-LABEL: func.func @extui_over_insert_3xi16
+// CHECK-SAME:    (%[[VEC:.+]]: vector<3xi16>, %[[ELEM:.+]]: i16)
+// CHECK-NEXT:    %[[NARROW:.+]] = vector.insert %[[ELEM]], %[[VEC]] [1] : i16 into vector<3xi16>
+// CHECK-NEXT:    %[[WIDE:.+]] = arith.extui %[[NARROW]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT:    return %[[WIDE]] : vector<3xi32>
+func.func @extui_over_insert_3xi16(%vec: vector<3xi16>, %elem: i16) -> vector<3xi32> {
+  %vec_ext = arith.extui %vec : vector<3xi16> to vector<3xi32>
+  %elem_ext = arith.extui %elem : i16 to i32
+  %ins = vector.insert %elem_ext, %vec_ext [1] : i32 into vector<3xi32>
+  return %ins : vector<3xi32>
+}
+
+// A splat-zero constant destination needs only one bit, so the i16 inserted
+// value determines the narrow width.
+// CHECK-LABEL: func.func @extsi_over_insert_3xi16_cst_0
+// CHECK-SAME:    (%[[ELEM:.+]]: i16)
+// CHECK-NEXT:    %[[ZEROS:.+]] = arith.constant dense<0> : vector<3xi16>
+// CHECK-NEXT:    %[[NARROW:.+]] = vector.insert %[[ELEM]], %[[ZEROS]] [1] : i16 into vector<3xi16>
+// CHECK-NEXT:    %[[WIDE:.+]] = arith.extsi %[[NARROW]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT:    return %[[WIDE]] : vector<3xi32>
+func.func @extsi_over_insert_3xi16_cst_0(%elem: i16) -> vector<3xi32> {
+  %zeros = arith.constant dense<0> : vector<3xi32>
+  %elem_ext = arith.extsi %elem : i16 to i32
+  %ins = vector.insert %elem_ext, %zeros [1] : i32 into vector<3xi32>
+  return %ins : vector<3xi32>
+}
+
+// Every constant element (-1, 127, -128) fits in a signed i8, matching the
+// i8 source of the inserted value.
+// CHECK-LABEL: func.func @extsi_over_insert_3xi8_cst
+// CHECK-SAME:    (%[[ELEM:.+]]: i8)
+// CHECK-NEXT:    %[[DEST:.+]] = arith.constant dense<[-1, 127, -128]> : vector<3xi8>
+// CHECK-NEXT:    %[[NARROW:.+]] = vector.insert %[[ELEM]], %[[DEST]] [1] : i8 into vector<3xi8>
+// CHECK-NEXT:    %[[WIDE:.+]] = arith.extsi %[[NARROW]] : vector<3xi8> to vector<3xi32>
+// CHECK-NEXT:    return %[[WIDE]] : vector<3xi32>
+func.func @extsi_over_insert_3xi8_cst(%elem: i8) -> vector<3xi32> {
+  %dest = arith.constant dense<[-1, 127, -128]> : vector<3xi32>
+  %elem_ext = arith.extsi %elem : i8 to i32
+  %ins = vector.insert %elem_ext, %dest [1] : i32 into vector<3xi32>
+  return %ins : vector<3xi32>
+}
+
+// Every constant element (1, 127, 255) fits in an unsigned i8; once narrowed,
+// 255 is printed as -1.
+// CHECK-LABEL: func.func @extui_over_insert_3xi8_cst
+// CHECK-SAME:    (%[[ELEM:.+]]: i8)
+// CHECK-NEXT:    %[[DEST:.+]] = arith.constant dense<[1, 127, -1]> : vector<3xi8>
+// CHECK-NEXT:    %[[NARROW:.+]] = vector.insert %[[ELEM]], %[[DEST]] [1] : i8 into vector<3xi8>
+// CHECK-NEXT:    %[[WIDE:.+]] = arith.extui %[[NARROW]] : vector<3xi8> to vector<3xi32>
+// CHECK-NEXT:    return %[[WIDE]] : vector<3xi32>
+func.func @extui_over_insert_3xi8_cst(%elem: i8) -> vector<3xi32> {
+  %dest = arith.constant dense<[1, 127, 255]> : vector<3xi32>
+  %elem_ext = arith.extui %elem : i8 to i32
+  %ins = vector.insert %elem_ext, %dest [1] : i32 into vector<3xi32>
+  return %ins : vector<3xi32>
+}
+
+// The constant 128 does not fit in a signed i8, so the destination is narrowed
+// to i16 only. The inserted value comes from an i8, so its i32 extension is
+// truncated to i16 instead of being folded away.
+// CHECK-LABEL: func.func @extsi_over_insert_3xi16_cst_i16
+// CHECK-SAME:    (%[[ELEM:.+]]: i8)
+// CHECK-NEXT:    %[[DEST:.+]] = arith.constant dense<[-1, 128, 0]> : vector<3xi16>
+// CHECK-NEXT:    %[[ELEM_EXT:.+]] = arith.extsi %[[ELEM]] : i8 to i32
+// CHECK-NEXT:    %[[ELEM_TRUNC:.+]] = arith.trunci %[[ELEM_EXT]] : i32 to i16
+// CHECK-NEXT:    %[[NARROW:.+]] = vector.insert %[[ELEM_TRUNC]], %[[DEST]] [1] : i16 into vector<3xi16>
+// CHECK-NEXT:    %[[WIDE:.+]] = arith.extsi %[[NARROW]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT:    return %[[WIDE]] : vector<3xi32>
+func.func @extsi_over_insert_3xi16_cst_i16(%elem: i8) -> vector<3xi32> {
+  %dest = arith.constant dense<[-1, 128, 0]> : vector<3xi32>
+  %elem_ext = arith.extsi %elem : i8 to i32
+  %ins = vector.insert %elem_ext, %dest [1] : i32 into vector<3xi32>
+  return %ins : vector<3xi32>
+}
+
+// The constant 256 does not fit in an unsigned i8, so the destination is
+// narrowed to i16 only. The inserted value comes from an i8, so its i32
+// extension is truncated to i16 instead of being folded away.
+// CHECK-LABEL: func.func @extui_over_insert_3xi16_cst_i16
+// CHECK-SAME:    (%[[ARG:.+]]: i8)
+// CHECK-NEXT:    %[[CST:.+]]  = arith.constant dense<[1, 256, 0]> : vector<3xi16>
+// CHECK-NEXT:    %[[SRCE:.+]] = arith.extui %[[ARG]] : i8 to i32
+// CHECK-NEXT:    %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16
+// CHECK-NEXT:    %[[INS:.+]]  = vector.insert %[[SRCT]], %[[CST]] [1] : i16 into vector<3xi16>
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
+func.func @extui_over_insert_3xi16_cst_i16(%a: i8) -> vector<3xi32> {
+  %cst = arith.constant dense<[1, 256, 0]> : vector<3xi32>
+  %d = arith.extui %a : i8 to i32
+  %e = vector.insert %d, %cst [1] : i32 into vector<3xi32>
+  return %e : vector<3xi32>
+}
+
+// Negative test: the destination is sign-extended but the inserted value is
+// zero-extended. No single extension kind can be commuted over the insertion,
+// so the IR must stay unchanged.
+// CHECK-LABEL: func.func @extui_over_insert_extsi_dest_negative
+// CHECK-SAME:    (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16)
+// CHECK-NEXT:    %[[EXTS:.+]] = arith.extsi %[[ARG0]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT:    %[[EXTU:.+]] = arith.extui %[[ARG1]] : i16 to i32
+// CHECK-NEXT:    %[[INS:.+]] = vector.insert %[[EXTU]], %[[EXTS]] [1] : i32 into vector<3xi32>
+// CHECK-NEXT:    return %[[INS]] : vector<3xi32>
+func.func @extui_over_insert_extsi_dest_negative(%a: vector<3xi16>, %b: i16) -> vector<3xi32> {
+  %c = arith.extsi %a : vector<3xi16> to vector<3xi32>
+  %d = arith.extui %b : i16 to i32
+  %e = vector.insert %d, %c [1] : i32 into vector<3xi32>
+  return %e : vector<3xi32>
+}


        


More information about the Mlir-commits mailing list