[Mlir-commits] [mlir] fc19204 - [mlir][arith] Add narrowing patterns for index casts
Jakub Kuderski
llvmlistbot at llvm.org
Wed May 3 10:56:21 PDT 2023
Author: Jakub Kuderski
Date: 2023-05-03T13:55:02-04:00
New Revision: fc19204918136074483e576177e05bccb2543d44
URL: https://github.com/llvm/llvm-project/commit/fc19204918136074483e576177e05bccb2543d44
DIFF: https://github.com/llvm/llvm-project/commit/fc19204918136074483e576177e05bccb2543d44.diff
LOG: [mlir][arith] Add narrowing patterns for index casts
These rely on the `ValueBounds` interface and its `computeConstantBound`
utility to compute constant bounds. This allows us to optimize `linalg.index`
values cast to integer types.
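As an illustrative sketch (not taken verbatim from this commit; value names
are made up), in a `linalg.generic` iterating over a `tensor<128xf16>` the
value of `linalg.index 0` is known to lie in [0, 127], so a cast to i64 only
needs 8 bits:

  %idx = linalg.index 0 : index
  %int = arith.index_cast %idx : index to i64

The new pattern rewrites the cast into a narrower cast followed by an
extension:

  %idx    = linalg.index 0 : index
  %narrow = arith.index_cast %idx : index to i8
  %int    = arith.extsi %narrow : i8 to i64

The other narrowing patterns in this pass can then absorb the extension into
consumers such as `arith.sitofp`, which is what the new test file checks.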
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D149538
Added:
mlir/test/Dialect/Linalg/int-narrowing.mlir
Modified:
mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 344caff0b5850..4f294e6e4c91e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -8,7 +8,9 @@
#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -18,6 +20,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
@@ -419,6 +422,65 @@ struct IToFPPattern final : NarrowingPattern<IToFPOp> {
using SIToFPPattern = IToFPPattern<arith::SIToFPOp, ExtensionKind::Sign>;
using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;
+//===----------------------------------------------------------------------===//
+// Index Cast Patterns
+//===----------------------------------------------------------------------===//
+
+// These rely on the `ValueBounds` interface for index values. For example, we
+// can often statically tell index value bounds of loop induction variables.
+
+template <typename CastOp, ExtensionKind Kind>
+struct IndexCastPattern final : NarrowingPattern<CastOp> {
+ using NarrowingPattern<CastOp>::NarrowingPattern;
+
+ LogicalResult matchAndRewrite(CastOp op,
+ PatternRewriter &rewriter) const override {
+ Value in = op.getIn();
+ // We only support scalar index -> integer casts.
+ if (!isa<IndexType>(in.getType()))
+ return failure();
+
+ // Check the lower bound in both the signed and unsigned cast case. We
+ // conservatively assume that even unsigned casts may be performed on
+ // negative indices.
+ FailureOr<int64_t> lb = ValueBoundsConstraintSet::computeConstantBound(
+ presburger::BoundType::LB, in);
+ if (failed(lb))
+ return failure();
+
+ FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
+ presburger::BoundType::UB, in, /*dim=*/std::nullopt,
+ /*stopCondition=*/nullptr, /*closedUB=*/true);
+ if (failed(ub))
+ return failure();
+
+ assert(*lb <= *ub && "Invalid bounds");
+ unsigned lbBitsRequired = calculateBitsRequired(APInt(64, *lb), Kind);
+ unsigned ubBitsRequired = calculateBitsRequired(APInt(64, *ub), Kind);
+ unsigned bitsRequired = std::max(lbBitsRequired, ubBitsRequired);
+
+ IntegerType resultTy = cast<IntegerType>(op.getType());
+ if (resultTy.getWidth() <= bitsRequired)
+ return failure();
+
+ FailureOr<Type> narrowTy = this->getNarrowType(bitsRequired, resultTy);
+ if (failed(narrowTy))
+ return failure();
+
+ Value newCast = rewriter.create<CastOp>(op.getLoc(), *narrowTy, op.getIn());
+
+ if (Kind == ExtensionKind::Sign)
+ rewriter.replaceOpWithNewOp<arith::ExtSIOp>(op, resultTy, newCast);
+ else
+ rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, resultTy, newCast);
+ return success();
+ }
+};
+using IndexCastSIPattern =
+ IndexCastPattern<arith::IndexCastOp, ExtensionKind::Sign>;
+using IndexCastUIPattern =
+ IndexCastPattern<arith::IndexCastUIOp, ExtensionKind::Zero>;
+
//===----------------------------------------------------------------------===//
// Patterns to Commute Extension Ops
//===----------------------------------------------------------------------===//
@@ -714,8 +776,8 @@ void populateArithIntNarrowingPatterns(
patterns.add<AddIPattern, SubIPattern, MulIPattern, DivSIPattern,
DivUIPattern, MaxSIPattern, MaxUIPattern, MinSIPattern,
- MinUIPattern, SIToFPPattern, UIToFPPattern>(
- patterns.getContext(), options);
+ MinUIPattern, SIToFPPattern, UIToFPPattern, IndexCastSIPattern,
+ IndexCastUIPattern>(patterns.getContext(), options);
}
} // namespace mlir::arith
diff --git a/mlir/test/Dialect/Linalg/int-narrowing.mlir b/mlir/test/Dialect/Linalg/int-narrowing.mlir
new file mode 100644
index 0000000000000..8063d504597a3
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/int-narrowing.mlir
@@ -0,0 +1,147 @@
+// RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,32" \
+// RUN: --verify-diagnostics %s | FileCheck %s
+
+// Check that we can calculate `linalg.index` value bounds and use them to
+// optimize index casts.
+
+//===----------------------------------------------------------------------===//
+// arith.index_cast
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @linalg_indexcast_dim_0_i8
+// CHECK: %[[IDX:.+]] = linalg.index 0 : index
+// CHECK-NEXT: %[[INT:.+]] = arith.index_cast %[[IDX]] : index to i8
+// CHECK-NEXT: %[[FP:.+]] = arith.sitofp %[[INT]] : i8 to f16
+// CHECK-NEXT: linalg.yield %[[FP]] : f16
+func.func @linalg_indexcast_dim_0_i8(%arg0: tensor<f16>) -> tensor<128xf16> {
+ %init = tensor.empty() : tensor<128xf16>
+ %res = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]
+ }
+ ins(%arg0 : tensor<f16>)
+ outs(%init : tensor<128xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ %idx = linalg.index 0 : index
+ %int = arith.index_cast %idx : index to i64
+ %fp = arith.sitofp %int : i64 to f16
+ linalg.yield %fp : f16
+ } -> tensor<128xf16>
+
+ return %res : tensor<128xf16>
+}
+
+// CHECK-LABEL: func @linalg_indexcast_dim_1_i16
+// CHECK: %[[IDX:.+]] = linalg.index 1 : index
+// CHECK-NEXT: %[[INT:.+]] = arith.index_cast %[[IDX]] : index to i16
+// CHECK-NEXT: %[[FP:.+]] = arith.sitofp %[[INT]] : i16 to f16
+// CHECK-NEXT: linalg.yield %[[FP]] : f16
+func.func @linalg_indexcast_dim_1_i16(%arg0: tensor<f16>, %arg1: tensor<?x129xf16>) -> tensor<?x129xf16> {
+ %res = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ }
+ ins(%arg0 : tensor<f16>)
+ outs(%arg1 : tensor<?x129xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ %idx = linalg.index 1 : index
+ %int = arith.index_cast %idx : index to i64
+ %fp = arith.sitofp %int : i64 to f16
+ linalg.yield %fp : f16
+ } -> tensor<?x129xf16>
+
+ return %res : tensor<?x129xf16>
+}
+
+// CHECK-LABEL: func @linalg_indexcast_dynamic_dim_i64
+// CHECK: %[[IDX:.+]] = linalg.index 0 : index
+// CHECK-NEXT: %[[INT:.+]] = arith.index_cast %[[IDX]] : index to i64
+// CHECK-NEXT: %[[FP:.+]] = arith.sitofp %[[INT]] : i64 to f16
+// CHECK-NEXT: linalg.yield %[[FP]] : f16
+func.func @linalg_indexcast_dynamic_dim_i64(%arg0: tensor<f16>, %arg1: tensor<?xf16>) -> tensor<?xf16> {
+ %res = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]
+ }
+ ins(%arg0 : tensor<f16>)
+ outs(%arg1 : tensor<?xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ %idx = linalg.index 0 : index
+ %int = arith.index_cast %idx : index to i64
+ %fp = arith.sitofp %int : i64 to f16
+ linalg.yield %fp : f16
+ } -> tensor<?xf16>
+
+ return %res : tensor<?xf16>
+}
+
+//===----------------------------------------------------------------------===//
+// arith.index_castui
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @linalg_indexcastui_dim_0_i8
+// CHECK: %[[IDX:.+]] = linalg.index 0 : index
+// CHECK-NEXT: %[[INT:.+]] = arith.index_castui %[[IDX]] : index to i8
+// CHECK-NEXT: %[[FP:.+]] = arith.uitofp %[[INT]] : i8 to f16
+// CHECK-NEXT: linalg.yield %[[FP]] : f16
+func.func @linalg_indexcastui_dim_0_i8(%arg0: tensor<f16>) -> tensor<256xf16> {
+ %init = tensor.empty() : tensor<256xf16>
+ %res = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]
+ }
+ ins(%arg0 : tensor<f16>)
+ outs(%init : tensor<256xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ %idx = linalg.index 0 : index
+ %int = arith.index_castui %idx : index to i64
+ %fp = arith.uitofp %int : i64 to f16
+ linalg.yield %fp : f16
+ } -> tensor<256xf16>
+
+ return %res : tensor<256xf16>
+}
+
+// CHECK-LABEL: func @linalg_indexcastui_dim_1_i16
+// CHECK: %[[IDX:.+]] = linalg.index 1 : index
+// CHECK-NEXT: %[[INT:.+]] = arith.index_castui %[[IDX]] : index to i16
+// CHECK-NEXT: %[[FP:.+]] = arith.uitofp %[[INT]] : i16 to f16
+// CHECK-NEXT: linalg.yield %[[FP]] : f16
+func.func @linalg_indexcastui_dim_1_i16(%arg0: tensor<f16>, %arg1: tensor<?x257xf16>) -> tensor<?x257xf16> {
+ %res = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ }
+ ins(%arg0 : tensor<f16>)
+ outs(%arg1 : tensor<?x257xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ %idx = linalg.index 1 : index
+ %int = arith.index_castui %idx : index to i64
+ %fp = arith.uitofp %int : i64 to f16
+ linalg.yield %fp : f16
+ } -> tensor<?x257xf16>
+
+ return %res : tensor<?x257xf16>
+}
+
+// CHECK-LABEL: func @linalg_indexcastui_dynamic_dim_i64
+// CHECK: %[[IDX:.+]] = linalg.index 0 : index
+// CHECK-NEXT: %[[INT:.+]] = arith.index_castui %[[IDX]] : index to i64
+// CHECK-NEXT: %[[FP:.+]] = arith.uitofp %[[INT]] : i64 to f16
+// CHECK-NEXT: linalg.yield %[[FP]] : f16
+func.func @linalg_indexcastui_dynamic_dim_i64(%arg0: tensor<f16>, %arg1: tensor<?xf16>) -> tensor<?xf16> {
+ %res = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]
+ }
+ ins(%arg0 : tensor<f16>)
+ outs(%arg1 : tensor<?xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ %idx = linalg.index 0 : index
+ %int = arith.index_castui %idx : index to i64
+ %fp = arith.uitofp %int : i64 to f16
+ linalg.yield %fp : f16
+ } -> tensor<?xf16>
+
+ return %res : tensor<?xf16>
+}