[Mlir-commits] [mlir] fc19204 - [mlir][arith] Add narrowing patterns for index casts

Jakub Kuderski llvmlistbot at llvm.org
Wed May 3 10:56:21 PDT 2023


Author: Jakub Kuderski
Date: 2023-05-03T13:55:02-04:00
New Revision: fc19204918136074483e576177e05bccb2543d44

URL: https://github.com/llvm/llvm-project/commit/fc19204918136074483e576177e05bccb2543d44
DIFF: https://github.com/llvm/llvm-project/commit/fc19204918136074483e576177e05bccb2543d44.diff

LOG: [mlir][arith] Add narrowing patterns for index casts

These rely on the `ValueBounds` interface and its utility function to
compute constant bounds. This allows us to optimize `linalg.index`
values cast to integer types.
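For example, given a `linalg.generic` with a static `tensor<128xf16>` output,
`linalg.index 0` is known to lie in [0, 127], so a chain such as

    %idx = linalg.index 0 : index
    %int = arith.index_cast %idx : index to i64
    %fp = arith.sitofp %int : i64 to f16

narrows to i8 once these patterns run together with the existing `sitofp`
narrowing (this mirrors the @linalg_indexcast_dim_0_i8 test added below):

    %idx = linalg.index 0 : index
    %int = arith.index_cast %idx : index to i8
    %fp = arith.sitofp %int : i8 to f16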

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D149538

Added: 
    mlir/test/Dialect/Linalg/int-narrowing.mlir

Modified: 
    mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 344caff0b5850..4f294e6e4c91e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -8,7 +8,9 @@
 
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Transforms.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -18,6 +20,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
@@ -419,6 +422,65 @@ struct IToFPPattern final : NarrowingPattern<IToFPOp> {
 using SIToFPPattern = IToFPPattern<arith::SIToFPOp, ExtensionKind::Sign>;
 using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;
 
+//===----------------------------------------------------------------------===//
+// Index Cast Patterns
+//===----------------------------------------------------------------------===//
+
+// These rely on the `ValueBounds` interface for index values. For example, we
+// can often statically tell index value bounds of loop induction variables.
+
+template <typename CastOp, ExtensionKind Kind>
+struct IndexCastPattern final : NarrowingPattern<CastOp> {
+  using NarrowingPattern<CastOp>::NarrowingPattern;
+
+  LogicalResult matchAndRewrite(CastOp op,
+                                PatternRewriter &rewriter) const override {
+    Value in = op.getIn();
+    // We only support scalar index -> integer casts.
+    if (!isa<IndexType>(in.getType()))
+      return failure();
+
+    // Check the lower bound in both the signed and unsigned cast case. We
+    // conservatively assume that even unsigned casts may be performed on
+    // negative indices.
+    FailureOr<int64_t> lb = ValueBoundsConstraintSet::computeConstantBound(
+        presburger::BoundType::LB, in);
+    if (failed(lb))
+      return failure();
+
+    FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
+        presburger::BoundType::UB, in, /*dim=*/std::nullopt,
+        /*stopCondition=*/nullptr, /*closedUB=*/true);
+    if (failed(ub))
+      return failure();
+
+    assert(*lb <= *ub && "Invalid bounds");
+    unsigned lbBitsRequired = calculateBitsRequired(APInt(64, *lb), Kind);
+    unsigned ubBitsRequired = calculateBitsRequired(APInt(64, *ub), Kind);
+    unsigned bitsRequired = std::max(lbBitsRequired, ubBitsRequired);
+
+    IntegerType resultTy = cast<IntegerType>(op.getType());
+    if (resultTy.getWidth() <= bitsRequired)
+      return failure();
+
+    FailureOr<Type> narrowTy = this->getNarrowType(bitsRequired, resultTy);
+    if (failed(narrowTy))
+      return failure();
+
+    Value newCast = rewriter.create<CastOp>(op.getLoc(), *narrowTy, op.getIn());
+
+    if (Kind == ExtensionKind::Sign)
+      rewriter.replaceOpWithNewOp<arith::ExtSIOp>(op, resultTy, newCast);
+    else
+      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, resultTy, newCast);
+    return success();
+  }
+};
+using IndexCastSIPattern =
+    IndexCastPattern<arith::IndexCastOp, ExtensionKind::Sign>;
+using IndexCastUIPattern =
+    IndexCastPattern<arith::IndexCastUIOp, ExtensionKind::Zero>;
+
 //===----------------------------------------------------------------------===//
 // Patterns to Commute Extension Ops
 //===----------------------------------------------------------------------===//
@@ -714,8 +776,8 @@ void populateArithIntNarrowingPatterns(
 
   patterns.add<AddIPattern, SubIPattern, MulIPattern, DivSIPattern,
                DivUIPattern, MaxSIPattern, MaxUIPattern, MinSIPattern,
-               MinUIPattern, SIToFPPattern, UIToFPPattern>(
-      patterns.getContext(), options);
+               MinUIPattern, SIToFPPattern, UIToFPPattern, IndexCastSIPattern,
+               IndexCastUIPattern>(patterns.getContext(), options);
 }
 
 } // namespace mlir::arith
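
For users of the C++ API, the new patterns are picked up through
`populateArithIntNarrowingPatterns` (extended above). Below is a minimal
sketch of driving them to fixpoint; the exact header and the
`bitwidthsSupported` option field name are assumptions based on the pass
definition, not part of this patch:

    #include "mlir/Dialect/Arith/Transforms/Passes.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    using namespace mlir;

    static LogicalResult narrowIntegerOps(Operation *root) {
      RewritePatternSet patterns(root->getContext());
      arith::ArithIntNarrowingOptions options;
      // Same bitwidths as the RUN line in the new test (assumed field name).
      options.bitwidthsSupported = {1, 8, 16, 32};
      arith::populateArithIntNarrowingPatterns(patterns, options);
      // Apply the narrowing patterns greedily until no more apply.
      return applyPatternsAndFoldGreedily(root, std::move(patterns));
    }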

diff --git a/mlir/test/Dialect/Linalg/int-narrowing.mlir b/mlir/test/Dialect/Linalg/int-narrowing.mlir
new file mode 100644
index 0000000000000..8063d504597a3
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/int-narrowing.mlir
@@ -0,0 +1,147 @@
+// RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,32" \
+// RUN:          --verify-diagnostics %s | FileCheck %s
+
+// Check that we can calculate `linalg.index` value bounds and use them to
+// optimize index casts.
+
+//===----------------------------------------------------------------------===//
+// arith.index_cast
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @linalg_indexcast_dim_0_i8
+// CHECK:         %[[IDX:.+]] = linalg.index 0 : index
+// CHECK-NEXT:    %[[INT:.+]] = arith.index_cast %[[IDX]] : index to i8
+// CHECK-NEXT:    %[[FP:.+]]  = arith.sitofp %[[INT]] : i8 to f16
+// CHECK-NEXT:    linalg.yield %[[FP]] : f16
+func.func @linalg_indexcast_dim_0_i8(%arg0: tensor<f16>) -> tensor<128xf16> {
+  %init = tensor.empty() : tensor<128xf16>
+  %res = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]
+    }
+    ins(%arg0 : tensor<f16>)
+    outs(%init : tensor<128xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %idx = linalg.index 0 : index
+    %int = arith.index_cast %idx : index to i64
+    %fp = arith.sitofp %int : i64 to f16
+    linalg.yield %fp : f16
+  } -> tensor<128xf16>
+
+  return %res : tensor<128xf16>
+}
+
+// CHECK-LABEL: func @linalg_indexcast_dim_1_i16
+// CHECK:         %[[IDX:.+]] = linalg.index 1 : index
+// CHECK-NEXT:    %[[INT:.+]] = arith.index_cast %[[IDX]] : index to i16
+// CHECK-NEXT:    %[[FP:.+]]  = arith.sitofp %[[INT]] : i16 to f16
+// CHECK-NEXT:    linalg.yield %[[FP]] : f16
+func.func @linalg_indexcast_dim_1_i16(%arg0: tensor<f16>, %arg1: tensor<?x129xf16>) -> tensor<?x129xf16> {
+  %res = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]
+    }
+    ins(%arg0 : tensor<f16>)
+    outs(%arg1 : tensor<?x129xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %idx = linalg.index 1 : index
+    %int = arith.index_cast %idx : index to i64
+    %fp = arith.sitofp %int : i64 to f16
+    linalg.yield %fp : f16
+  } -> tensor<?x129xf16>
+
+  return %res : tensor<?x129xf16>
+}
+
+// CHECK-LABEL: func @linalg_indexcast_dynamic_dim_i64
+// CHECK:         %[[IDX:.+]] = linalg.index 0 : index
+// CHECK-NEXT:    %[[INT:.+]] = arith.index_cast %[[IDX]] : index to i64
+// CHECK-NEXT:    %[[FP:.+]]  = arith.sitofp %[[INT]] : i64 to f16
+// CHECK-NEXT:    linalg.yield %[[FP]] : f16
+func.func @linalg_indexcast_dynamic_dim_i64(%arg0: tensor<f16>, %arg1: tensor<?xf16>) -> tensor<?xf16> {
+  %res = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]
+    }
+    ins(%arg0 : tensor<f16>)
+    outs(%arg1 : tensor<?xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %idx = linalg.index 0 : index
+    %int = arith.index_cast %idx : index to i64
+    %fp = arith.sitofp %int : i64 to f16
+    linalg.yield %fp : f16
+  } -> tensor<?xf16>
+
+  return %res : tensor<?xf16>
+}
+
+//===----------------------------------------------------------------------===//
+// arith.index_castui
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @linalg_indexcastui_dim_0_i8
+// CHECK:         %[[IDX:.+]] = linalg.index 0 : index
+// CHECK-NEXT:    %[[INT:.+]] = arith.index_castui %[[IDX]] : index to i8
+// CHECK-NEXT:    %[[FP:.+]]  = arith.uitofp %[[INT]] : i8 to f16
+// CHECK-NEXT:    linalg.yield %[[FP]] : f16
+func.func @linalg_indexcastui_dim_0_i8(%arg0: tensor<f16>) -> tensor<256xf16> {
+  %init = tensor.empty() : tensor<256xf16>
+  %res = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]
+    }
+    ins(%arg0 : tensor<f16>)
+    outs(%init : tensor<256xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %idx = linalg.index 0 : index
+    %int = arith.index_castui %idx : index to i64
+    %fp = arith.uitofp %int : i64 to f16
+    linalg.yield %fp : f16
+  } -> tensor<256xf16>
+
+  return %res : tensor<256xf16>
+}
+
+// CHECK-LABEL: func @linalg_indexcastui_dim_1_i16
+// CHECK:         %[[IDX:.+]] = linalg.index 1 : index
+// CHECK-NEXT:    %[[INT:.+]] = arith.index_castui %[[IDX]] : index to i16
+// CHECK-NEXT:    %[[FP:.+]]  = arith.uitofp %[[INT]] : i16 to f16
+// CHECK-NEXT:    linalg.yield %[[FP]] : f16
+func.func @linalg_indexcastui_dim_1_i16(%arg0: tensor<f16>, %arg1: tensor<?x257xf16>) -> tensor<?x257xf16> {
+  %res = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]
+    }
+    ins(%arg0 : tensor<f16>)
+    outs(%arg1 : tensor<?x257xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %idx = linalg.index 1 : index
+    %int = arith.index_castui %idx : index to i64
+    %fp = arith.uitofp %int : i64 to f16
+    linalg.yield %fp : f16
+  } -> tensor<?x257xf16>
+
+  return %res : tensor<?x257xf16>
+}
+
+// CHECK-LABEL: func @linalg_indexcastui_dynamic_dim_i64
+// CHECK:         %[[IDX:.+]] = linalg.index 0 : index
+// CHECK-NEXT:    %[[INT:.+]] = arith.index_castui %[[IDX]] : index to i64
+// CHECK-NEXT:    %[[FP:.+]]  = arith.uitofp %[[INT]] : i64 to f16
+// CHECK-NEXT:    linalg.yield %[[FP]] : f16
+func.func @linalg_indexcastui_dynamic_dim_i64(%arg0: tensor<f16>, %arg1: tensor<?xf16>) -> tensor<?xf16> {
+  %res = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]
+    }
+    ins(%arg0 : tensor<f16>)
+    outs(%arg1 : tensor<?xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %idx = linalg.index 0 : index
+    %int = arith.index_castui %idx : index to i64
+    %fp = arith.uitofp %int : i64 to f16
+    linalg.yield %fp : f16
+  } -> tensor<?xf16>
+
+  return %res : tensor<?xf16>
+}


        

