[Mlir-commits] [mlir] da0730b - [mlir][arith] Add initial integer bitwidth narrowing pass

Jakub Kuderski llvmlistbot at llvm.org
Tue Apr 25 19:33:58 PDT 2023


Author: Jakub Kuderski
Date: 2023-04-25T22:33:11-04:00
New Revision: da0730b908a43e490430717beda8486598667ab8

URL: https://github.com/llvm/llvm-project/commit/da0730b908a43e490430717beda8486598667ab8
DIFF: https://github.com/llvm/llvm-project/commit/da0730b908a43e490430717beda8486598667ab8.diff

LOG: [mlir][arith] Add initial integer bitwidth narrowing pass

This pass reduces the logical complexity of arith ops by choosing
the narrowest supported operand bitwidth. On some targets like mobile GPUs,
narrower bitwidths also bring better runtime performance.

The first batch of rewrites handles a simple case of `arith.sitofp`
and `arith.uitofp` with zero/sign-extended inputs. In future revisions,
I plan to extend it with the following:
-  Propagating sign/zero-extensions through bit-pattern-preserving ops,
   e.g., vector transpose, broadcast, insertions/extractions.
-  Handling `linalg.index` using the `ValueBounds` interface.
-  Handling more arith ops.

Reviewed By: springerm, antiagainst

Differential Revision: https://reviews.llvm.org/D149118

Added: 
    mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
    mlir/test/Dialect/Arith/int-narrowing.mlir

Modified: 
    mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
    mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
    mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 1aa2e55dfdc98..c4010b7c0b57a 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -62,6 +62,10 @@ void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
 /// Create a pass which do optimizations based on integer range analysis.
 std::unique_ptr<Pass> createIntRangeOptimizationsPass();
 
+/// Add patterns for integer bitwidth narrowing of arith ops.
+/// `options.bitwidthsSupported` lists the integer bitwidths the target
+/// supports; the patterns use it to choose the narrowest usable type.
+void populateArithIntNarrowingPatterns(RewritePatternSet &patterns,
+                                       const ArithIntNarrowingOptions &options);
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 29b5e6f1dee86..e6fe4680d77cb 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -83,4 +83,17 @@ def ArithEmulateWideInt : Pass<"arith-emulate-wide-int"> {
   let dependentDialects = ["vector::VectorDialect"];
 }
 
+// Integer bitwidth narrowing for arith ops. The `int-bitwidths-supported`
+// list option names the integer bitwidths the target supports.
+def ArithIntNarrowing : Pass<"arith-int-narrowing"> {
+  let summary = "Reduce integer operation bitwidth";
+  let description = [{
+    Reduce bitwidths of integer types used in arith operations. This pass
+    prefers the narrowest available integer bitwidths that are guaranteed to
+    produce the same results.
+  }];
+  let options = [
+    ListOption<"bitwidthsSupported", "int-bitwidths-supported", "unsigned",
+               "Integer bitwidths supported">,
+  ];
+ }
+
 #endif // MLIR_DIALECT_ARITH_TRANSFORMS_PASSES

diff  --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 80a612883e396..87d9bebfd2a7c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRArithTransforms
   Bufferize.cpp
   EmulateWideInt.cpp
   ExpandOps.cpp
+  IntNarrowing.cpp
   IntRangeOptimizations.cpp
   ReifyValueBounds.cpp
   UnsignedWhenEquivalent.cpp

diff  --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
new file mode 100644
index 0000000000000..e884a19bae1cb
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -0,0 +1,175 @@
+//===- IntNarrowing.cpp - Integer bitwidth reduction optimizations --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include <cassert>
+#include <cstdint>
+
+namespace mlir::arith {
+#define GEN_PASS_DEF_ARITHINTNARROWING
+#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
+} // namespace mlir::arith
+
+namespace mlir::arith {
+namespace {
+//===----------------------------------------------------------------------===//
+// Common Helpers
+//===----------------------------------------------------------------------===//
+
+/// The base for integer bitwidth narrowing patterns.
+///
+/// Holds the integer bitwidths supported by the target (taken from
+/// `ArithIntNarrowingOptions::bitwidthsSupported`, sorted ascending) and
+/// provides helpers that select the narrowest supported integer type able to
+/// represent a given number of bits.
+template <typename SourceOp>
+struct NarrowingPattern : OpRewritePattern<SourceOp> {
+  NarrowingPattern(MLIRContext *ctx, const ArithIntNarrowingOptions &options,
+                   PatternBenefit benefit = 1)
+      : OpRewritePattern<SourceOp>(ctx, benefit),
+        supportedBitwidths(options.bitwidthsSupported.begin(),
+                           options.bitwidthsSupported.end()) {
+    assert(!supportedBitwidths.empty() && "Invalid options");
+    assert(!llvm::is_contained(supportedBitwidths, 0) && "Invalid bitwidth");
+    llvm::sort(supportedBitwidths);
+  }
+
+  /// Returns the smallest supported bitwidth that is at least `bitsRequired`,
+  /// or failure when every supported bitwidth is too narrow.
+  FailureOr<unsigned>
+  getNarrowestCompatibleBitwidth(unsigned bitsRequired) const {
+    for (unsigned candidate : supportedBitwidths)
+      if (candidate >= bitsRequired)
+        return candidate;
+
+    return failure();
+  }
+
+  /// Returns the narrowest supported type that fits `bitsRequired`.
+  ///
+  /// `origTy` must be an integer type or a shaped type with an integer element
+  /// type; a shaped result keeps the original shape. Fails when:
+  ///  - no supported bitwidth can hold `bitsRequired` bits,
+  ///  - `origTy` is not integer-based, or
+  ///  - the best supported bitwidth is not strictly narrower than the original
+  ///    element type (nothing to gain, and `arith.trunci` cannot widen).
+  FailureOr<Type> getNarrowType(unsigned bitsRequired, Type origTy) const {
+    assert(origTy);
+    FailureOr<unsigned> bestBitwidth =
+        getNarrowestCompatibleBitwidth(bitsRequired);
+    if (failed(bestBitwidth))
+      return failure();
+
+    auto elemTy = dyn_cast<IntegerType>(getElementTypeOrSelf(origTy));
+    if (!elemTy)
+      return failure();
+
+    // Only narrow strictly: an equal width is a no-op and a wider one would
+    // require an (invalid) widening `arith.trunci`.
+    if (*bestBitwidth >= elemTy.getWidth())
+      return failure();
+
+    // Build the type from the selected *supported* bitwidth, not the raw
+    // requirement, so that we never emit a bitwidth excluded by the options.
+    auto newElemTy = IntegerType::get(origTy.getContext(), *bestBitwidth);
+    if (origTy == elemTy)
+      return newElemTy;
+
+    if (auto shapedTy = dyn_cast<ShapedType>(origTy))
+      return shapedTy.clone(shapedTy.getShape(), newElemTy);
+
+    return failure();
+  }
+
+private:
+  // Supported integer bitwidths in the ascending order.
+  llvm::SmallVector<unsigned, 6> supportedBitwidths;
+};
+
+/// Returns the number of bits needed to hold a value of `type`, i.e. the width
+/// of its integer (element) type. Fails for non-integer types.
+FailureOr<unsigned> calculateBitsRequired(Type type) {
+  assert(type);
+  auto elementIntTy = dyn_cast<IntegerType>(getElementTypeOrSelf(type));
+  if (!elementIntTy)
+    return failure();
+  return elementIntTy.getWidth();
+}
+
+/// Which integer extension op `calculateBitsRequired` may look through.
+enum class ExtensionKind { Sign, Zero };
+
+/// Returns the integer bitwidth required to represent `value`.
+/// If `value` is produced by the extension matching `lookThroughExtension`
+/// (`arith.extsi` for `Sign`, `arith.extui` for `Zero`), the width of that
+/// extension's input is reported instead of the extended width.
+FailureOr<unsigned> calculateBitsRequired(Value value,
+                                          ExtensionKind lookThroughExtension) {
+  Value source = value;
+  switch (lookThroughExtension) {
+  case ExtensionKind::Sign:
+    if (auto sext = value.getDefiningOp<arith::ExtSIOp>())
+      source = sext.getIn();
+    break;
+  case ExtensionKind::Zero:
+    if (auto zext = value.getDefiningOp<arith::ExtUIOp>())
+      source = zext.getIn();
+    break;
+  }
+
+  // Without a matching extension, `source` is still `value` itself.
+  return calculateBitsRequired(source.getType());
+}
+
+//===----------------------------------------------------------------------===//
+// *IToFPOp Patterns
+//===----------------------------------------------------------------------===//
+
+/// Rewrites `arith.sitofp` / `arith.uitofp` so that the integer operand is
+/// truncated to the narrowest supported type, looking through a matching
+/// sign/zero extension to find how many bits the operand really needs.
+template <typename IToFPOp, ExtensionKind Extension>
+struct IToFPPattern final : NarrowingPattern<IToFPOp> {
+  using NarrowingPattern<IToFPOp>::NarrowingPattern;
+
+  LogicalResult matchAndRewrite(IToFPOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.getIn();
+    FailureOr<unsigned> bitsNeeded = calculateBitsRequired(input, Extension);
+    if (failed(bitsNeeded))
+      return failure();
+
+    FailureOr<Type> targetTy =
+        this->getNarrowType(*bitsNeeded, input.getType());
+    if (failed(targetTy))
+      return failure();
+
+    auto loc = op.getLoc();
+    Value truncated =
+        rewriter.createOrFold<arith::TruncIOp>(loc, *targetTy, input);
+    rewriter.replaceOpWithNewOp<IToFPOp>(op, op.getType(), truncated);
+    return success();
+  }
+};
+using SIToFPPattern = IToFPPattern<arith::SIToFPOp, ExtensionKind::Sign>;
+using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;
+
+//===----------------------------------------------------------------------===//
+// Pass Definitions
+//===----------------------------------------------------------------------===//
+
+/// Pass wrapper: collects the narrowing patterns configured from the pass
+/// options and applies them greedily to the pass's root operation.
+struct ArithIntNarrowingPass final
+    : impl::ArithIntNarrowingBase<ArithIntNarrowingPass> {
+  using ArithIntNarrowingBase::ArithIntNarrowingBase;
+
+  void runOnOperation() override {
+    Operation *root = getOperation();
+    RewritePatternSet narrowingPatterns(root->getContext());
+    populateArithIntNarrowingPatterns(
+        narrowingPatterns, ArithIntNarrowingOptions{bitwidthsSupported});
+    LogicalResult converged =
+        applyPatternsAndFoldGreedily(root, std::move(narrowingPatterns));
+    if (failed(converged))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Public API
+//===----------------------------------------------------------------------===//
+
+/// Registers the `arith.sitofp` and `arith.uitofp` narrowing patterns, each
+/// configured with `options`, into `patterns`.
+void populateArithIntNarrowingPatterns(
+    RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<SIToFPPattern>(context, options);
+  patterns.add<UIToFPPattern>(context, options);
+}
+
+} // namespace mlir::arith

diff  --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir
new file mode 100644
index 0000000000000..21d5ab774c87b
--- /dev/null
+++ b/mlir/test/Dialect/Arith/int-narrowing.mlir
@@ -0,0 +1,133 @@
+// RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,32" \
+// RUN:          --verify-diagnostics %s | FileCheck %s
+
+// sitofp fed by an extsi from i16: the extension is dropped and the
+// conversion reads the i16 value directly, for scalar, vector, and tensor
+// operands alike.
+//
+// CHECK-LABEL: func.func @sitofp_extsi_i16
+// CHECK-SAME:    (%[[ARG:.+]]: i16)
+// CHECK-NEXT:    %[[RET:.+]] = arith.sitofp %[[ARG]] : i16 to f16
+// CHECK-NEXT:    return %[[RET]] : f16
+func.func @sitofp_extsi_i16(%a: i16) -> f16 {
+  %b = arith.extsi %a : i16 to i32
+  %f = arith.sitofp %b : i32 to f16
+  return %f : f16
+}
+
+// CHECK-LABEL: func.func @sitofp_extsi_vector_i16
+// CHECK-SAME:    (%[[ARG:.+]]: vector<3xi16>)
+// CHECK-NEXT:    %[[RET:.+]] = arith.sitofp %[[ARG]] : vector<3xi16> to vector<3xf16>
+// CHECK-NEXT:    return %[[RET]] : vector<3xf16>
+func.func @sitofp_extsi_vector_i16(%a: vector<3xi16>) -> vector<3xf16> {
+  %b = arith.extsi %a : vector<3xi16> to vector<3xi32>
+  %f = arith.sitofp %b : vector<3xi32> to vector<3xf16>
+  return %f : vector<3xf16>
+}
+
+// CHECK-LABEL: func.func @sitofp_extsi_tensor_i16
+// CHECK-SAME:    (%[[ARG:.+]]: tensor<3x?xi16>)
+// CHECK-NEXT:    %[[RET:.+]] = arith.sitofp %[[ARG]] : tensor<3x?xi16> to tensor<3x?xf16>
+// CHECK-NEXT:    return %[[RET]] : tensor<3x?xf16>
+func.func @sitofp_extsi_tensor_i16(%a: tensor<3x?xi16>) -> tensor<3x?xf16> {
+  %b = arith.extsi %a : tensor<3x?xi16> to tensor<3x?xi32>
+  %f = arith.sitofp %b : tensor<3x?xi32> to tensor<3x?xf16>
+  return %f : tensor<3x?xf16>
+}
+
+// Narrowing to i64 is not enabled in pass options: the RUN line allows only
+// 1, 8, 16, and 32 bits, none of which can hold the i64 source value, so the
+// IR is left unchanged.
+//
+// CHECK-LABEL: func.func @sitofp_extsi_i64
+// CHECK-SAME:    (%[[ARG:.+]]: i64)
+// CHECK-NEXT:    %[[EXT:.+]] = arith.extsi %[[ARG]] : i64 to i128
+// CHECK-NEXT:    %[[RET:.+]] = arith.sitofp %[[EXT]] : i128 to f32
+// CHECK-NEXT:    return %[[RET]] : f32
+func.func @sitofp_extsi_i64(%a: i64) -> f32 {
+  %b = arith.extsi %a : i64 to i128
+  %f = arith.sitofp %b : i128 to f32
+  return %f : f32
+}
+
+// uitofp fed by an extui from i16 reads the i16 value directly.
+//
+// CHECK-LABEL: func.func @uitofp_extui_i16
+// CHECK-SAME:    (%[[ARG:.+]]: i16)
+// CHECK-NEXT:    %[[RET:.+]] = arith.uitofp %[[ARG]] : i16 to f16
+// CHECK-NEXT:    return %[[RET]] : f16
+func.func @uitofp_extui_i16(%a: i16) -> f16 {
+  %b = arith.extui %a : i16 to i32
+  %f = arith.uitofp %b : i32 to f16
+  return %f : f16
+}
+
+// Chains of the same extension kind collapse all the way down to the
+// original i8 source.
+//
+// CHECK-LABEL: func.func @sitofp_extsi_extsi_i8
+// CHECK-SAME:    (%[[ARG:.+]]: i8)
+// CHECK-NEXT:    %[[RET:.+]] = arith.sitofp %[[ARG]] : i8 to f16
+// CHECK-NEXT:    return %[[RET]] : f16
+func.func @sitofp_extsi_extsi_i8(%a: i8) -> f16 {
+  %b = arith.extsi %a : i8 to i16
+  %c = arith.extsi %b : i16 to i32
+  %f = arith.sitofp %c : i32 to f16
+  return %f : f16
+}
+
+// CHECK-LABEL: func.func @uitofp_extui_extui_i8
+// CHECK-SAME:    (%[[ARG:.+]]: i8)
+// CHECK-NEXT:    %[[RET:.+]] = arith.uitofp %[[ARG]] : i8 to f16
+// CHECK-NEXT:    return %[[RET]] : f16
+func.func @uitofp_extui_extui_i8(%a: i8) -> f16 {
+  %b = arith.extui %a : i8 to i16
+  %c = arith.extui %b : i16 to i32
+  %f = arith.uitofp %c : i32 to f16
+  return %f : f16
+}
+
+// uitofp looks through the outer extui only; the inner extsi is not a zero
+// extension, so it is kept and the conversion narrows to i16 rather than i8.
+//
+// CHECK-LABEL: func.func @uitofp_extsi_extui_i8
+// CHECK-SAME:    (%[[ARG:.+]]: i8)
+// CHECK-NEXT:    %[[EXT:.+]] = arith.extsi %[[ARG]] : i8 to i16
+// CHECK-NEXT:    %[[RET:.+]] = arith.uitofp %[[EXT]] : i16 to f16
+// CHECK-NEXT:    return %[[RET]] : f16
+func.func @uitofp_extsi_extui_i8(%a: i8) -> f16 {
+  %b = arith.extsi %a : i8 to i16
+  %c = arith.extui %b : i16 to i32
+  %f = arith.uitofp %c : i32 to f16
+  return %f : f16
+}
+
+// extui of an i8 produced by trunci: the extension is dropped, the trunci is
+// kept, and the conversion reads the truncated i8 value.
+//
+// CHECK-LABEL: func.func @uitofp_trunci_extui_i8
+// CHECK-SAME:    (%[[ARG:.+]]: i16)
+// CHECK-NEXT:    %[[TR:.+]]  = arith.trunci %[[ARG]] : i16 to i8
+// CHECK-NEXT:    %[[RET:.+]] = arith.uitofp %[[TR]] : i8 to f16
+// CHECK-NEXT:    return %[[RET]] : f16
+func.func @uitofp_trunci_extui_i8(%a: i16) -> f16 {
+  %b = arith.trunci %a : i16 to i8
+  %c = arith.extui %b : i8 to i32
+  %f = arith.uitofp %c : i32 to f16
+  return %f : f16
+}
+
+// This should not be folded because arith.extui changes the signed
+// range of the number (sitofp reads its operand as signed). For example:
+//  extsi -1 : i16 to i32 ==> -1
+//  extui -1 : i16 to i32 ==> U16_MAX
+//
+// CHECK-LABEL: func.func @sitofp_extui_i16
+// CHECK-SAME:    (%[[ARG:.+]]: i16)
+// CHECK-NEXT:    %[[EXT:.+]] = arith.extui %[[ARG]] : i16 to i32
+// CHECK-NEXT:    %[[RET:.+]] = arith.sitofp %[[EXT]] : i32 to f16
+// CHECK-NEXT:    return %[[RET]] : f16
+func.func @sitofp_extui_i16(%a: i16) -> f16 {
+  %b = arith.extui %a : i16 to i32
+  %f = arith.sitofp %b : i32 to f16
+  return %f : f16
+}
+
+// This should not be folded because arith.extsi changes the unsigned
+// range of the number (uitofp reads its operand as unsigned). For example,
+// with results read as unsigned:
+//  extsi -1 : i16 to i32 ==> U32_MAX
+//  extui -1 : i16 to i32 ==> U16_MAX
+//
+// CHECK-LABEL: func.func @uitofp_extsi_i16
+// CHECK-SAME:    (%[[ARG:.+]]: i16)
+// CHECK-NEXT:    %[[EXT:.+]] = arith.extsi %[[ARG]] : i16 to i32
+// CHECK-NEXT:    %[[RET:.+]] = arith.uitofp %[[EXT]] : i32 to f16
+// CHECK-NEXT:    return %[[RET]] : f16
+func.func @uitofp_extsi_i16(%a: i16) -> f16 {
+  %b = arith.extsi %a : i16 to i32
+  %f = arith.uitofp %b : i32 to f16
+  return %f : f16
+}


        


More information about the Mlir-commits mailing list