[Mlir-commits] [mlir] 9ef3146 - [mlir][index] Add shl, shrs, and shru ops
Jeff Niu
llvmlistbot at llvm.org
Thu Nov 3 16:29:10 PDT 2022
Author: Jeff Niu
Date: 2022-11-03T16:29:04-07:00
New Revision: 9ef31465114dcd226c27f43b73f221e89a4fa83d
URL: https://github.com/llvm/llvm-project/commit/9ef31465114dcd226c27f43b73f221e89a4fa83d
DIFF: https://github.com/llvm/llvm-project/commit/9ef31465114dcd226c27f43b73f221e89a4fa83d.diff
LOG: [mlir][index] Add shl, shrs, and shru ops
This patch adds the left shift, signed right shift, and unsigned right
shift operations to the index dialect, along with folders and LLVM lowerings.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D137349
Added:
Modified:
mlir/include/mlir/Dialect/Index/IR/IndexOps.td
mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp
mlir/lib/Dialect/Index/IR/IndexOps.cpp
mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir
mlir/test/Dialect/Index/index-canonicalize.mlir
mlir/test/Dialect/Index/index-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index 0896f21954603..29f4c1eb151c5 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -280,6 +280,69 @@ def Index_MaxUOp : IndexBinaryOp<"maxu"> {
}];
}
+//===----------------------------------------------------------------------===//
+// ShlOp
+//===----------------------------------------------------------------------===//
+
+// Left shift of an `index` value. IndexToLLVM lowers it to `llvm.shl`, and
+// ShlOp::fold only folds constant shift amounts below 32, because larger
+// amounts are undefined when `index` is 32 bits wide.
+def Index_ShlOp : IndexBinaryOp<"shl"> {
+ let summary = "index shift left";
+ let description = [{
+ The `index.shl` operation shifts an index value to the left by a variable
+ amount. The low order bits are filled with zeroes. The RHS operand is always
+ treated as unsigned. If the RHS operand is equal to or greater than the
+ index bitwidth, the operation is undefined.
+
+ Example:
+
+ ```mlir
+ // c = a << b
+ %c = index.shl %a, %b
+ ```
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// ShrSOp
+//===----------------------------------------------------------------------===//
+
+// Arithmetic (sign-filling) right shift of an `index` value. IndexToLLVM
+// lowers it to `llvm.ashr`. ShrSOp::fold declines shift amounts of 32 or more
+// and uses the checked folding helper, because a right shift moves high bits
+// of the LHS into the low 32 bits of the result.
+def Index_ShrSOp : IndexBinaryOp<"shrs"> {
+ let summary = "signed index shift right";
+ let description = [{
+ The `index.shrs` operation shifts an index value to the right by a variable
+ amount. The LHS operand is treated as signed. The high order bits are filled
+ with copies of the most significant bit. If the RHS operand is equal to or
+ greater than the index bitwidth, the operation is undefined.
+
+ Example:
+
+ ```mlir
+ // c = a >> b
+ %c = index.shrs %a, %b
+ ```
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// ShrUOp
+//===----------------------------------------------------------------------===//
+
+// Logical (zero-filling) right shift of an `index` value. IndexToLLVM lowers
+// it to `llvm.lshr`. ShrUOp::fold declines shift amounts of 32 or more and
+// uses the checked folding helper, because a right shift moves high bits of
+// the LHS into the low 32 bits of the result.
+def Index_ShrUOp : IndexBinaryOp<"shru"> {
+ let summary = "unsigned index shift right";
+ let description = [{
+ The `index.shru` operation shifts an index value to the right by a variable
+ amount. The LHS operand is treated as unsigned. The high order bits are
+ filled with zeroes. If the RHS operand is equal to or greater than the index
+ bitwidth, the operation is undefined.
+
+ Example:
+
+ ```mlir
+ // c = a >> b
+ %c = index.shru %a, %b
+ ```
+ }];
+}
+
//===----------------------------------------------------------------------===//
// CastSOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp b/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp
index 844c57a74a198..4461d5121ef01 100644
--- a/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp
+++ b/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp
@@ -268,6 +268,11 @@ using ConvertIndexMaxS =
mlir::OneToOneConvertToLLVMPattern<MaxSOp, LLVM::SMaxOp>;
using ConvertIndexMaxU =
mlir::OneToOneConvertToLLVMPattern<MaxUOp, LLVM::UMaxOp>;
+// Shifts map one-to-one onto LLVM dialect ops: `shrs` (sign-filling) becomes
+// an arithmetic shift right and `shru` (zero-filling) a logical shift right.
+using ConvertIndexShl = mlir::OneToOneConvertToLLVMPattern<ShlOp, LLVM::ShlOp>;
+using ConvertIndexShrS =
+ mlir::OneToOneConvertToLLVMPattern<ShrSOp, LLVM::AShrOp>;
+using ConvertIndexShrU =
+ mlir::OneToOneConvertToLLVMPattern<ShrUOp, LLVM::LShrOp>;
using ConvertIndexBoolConstant =
mlir::OneToOneConvertToLLVMPattern<BoolConstantOp, LLVM::ConstantOp>;
@@ -290,6 +295,9 @@ void index::populateIndexToLLVMConversionPatterns(
ConvertIndexRemU,
ConvertIndexMaxS,
ConvertIndexMaxU,
+ ConvertIndexShl,
+ ConvertIndexShrS,
+ ConvertIndexShrU,
ConvertIndexCeilDivS,
ConvertIndexCeilDivU,
ConvertIndexFloorDivS,
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index fcbb076f2e16f..241fa416eddab 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -62,17 +62,19 @@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
/// the integer result, which in turn must satisfy the above property.
static OpFoldResult foldBinaryOpUnchecked(
ArrayRef<Attribute> operands,
- function_ref<APInt(const APInt &, const APInt &)> calculate) {
+ function_ref<Optional<APInt>(const APInt &, const APInt &)> calculate) {
assert(operands.size() == 2 && "binary operation expected 2 operands");
auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
if (!lhs || !rhs)
return {};
- APInt result = calculate(lhs.getValue(), rhs.getValue());
- assert(result.trunc(32) ==
+ // `calculate` may return None to decline the fold (ShlOp::fold does this for
+ // shift amounts of 32 or more); no attribute is produced in that case.
+ Optional<APInt> result = calculate(lhs.getValue(), rhs.getValue());
+ if (!result)
+ return {};
+ // Assert the "unchecked" contract: truncating the 64-bit result to 32 bits
+ // must give the same value as computing directly on 32-bit operands.
+ assert(result->trunc(32) ==
calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)));
- return IntegerAttr::get(IndexType::get(lhs.getContext()), std::move(result));
+ return IntegerAttr::get(IndexType::get(lhs.getContext()), std::move(*result));
}
/// Fold an index operation only if the truncated 64-bit result matches the
@@ -284,6 +286,50 @@ OpFoldResult MaxUOp::fold(ArrayRef<Attribute> operands) {
});
}
+//===----------------------------------------------------------------------===//
+// ShlOp
+//===----------------------------------------------------------------------===//
+
+// Folds `shl` of two constants. The unchecked helper suffices here: for shift
+// amounts below 32, the low 32 bits of `lhs << rhs` depend only on the low 32
+// bits of `lhs`, so the 64-bit and 32-bit results always agree.
+OpFoldResult ShlOp::fold(ArrayRef<Attribute> operands) {
+ return foldBinaryOpUnchecked(
+ operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
+ // We cannot fold if the RHS is greater than or equal to 32 because
+ // this would be UB in 32-bit systems but not on 64-bit systems. RHS is
+ // already treated as unsigned.
+ if (rhs.uge(32))
+ return {};
+ return lhs << rhs;
+ });
+}
+
+//===----------------------------------------------------------------------===//
+// ShrSOp
+//===----------------------------------------------------------------------===//
+
+// Folds `shrs` of two constants. A right shift moves high bits of `lhs` into
+// the low 32 bits of the result, so the checked helper is used: it folds only
+// when the 64-bit result, truncated to 32 bits, matches the 32-bit result.
+OpFoldResult ShrSOp::fold(ArrayRef<Attribute> operands) {
+ return foldBinaryOpChecked(
+ operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
+ // Don't fold if RHS is greater than or equal to 32.
+ // Such shift amounts are undefined when `index` is 32 bits wide.
+ if (rhs.uge(32))
+ return {};
+ return lhs.ashr(rhs);
+ });
+}
+
+//===----------------------------------------------------------------------===//
+// ShrUOp
+//===----------------------------------------------------------------------===//
+
+// Folds `shru` of two constants. A right shift moves high bits of `lhs` into
+// the low 32 bits of the result, so the checked helper is used: it folds only
+// when the 64-bit result, truncated to 32 bits, matches the 32-bit result.
+OpFoldResult ShrUOp::fold(ArrayRef<Attribute> operands) {
+ return foldBinaryOpChecked(
+ operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
+ // Don't fold if RHS is greater than or equal to 32.
+ // Such shift amounts are undefined when `index` is 32 bits wide.
+ if (rhs.uge(32))
+ return {};
+ return lhs.lshr(rhs);
+ });
+}
+
//===----------------------------------------------------------------------===//
// CastSOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir b/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir
index ee8e6629aa719..c6b2273fa1f3f 100644
--- a/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir
+++ b/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir
@@ -22,8 +22,14 @@ func.func @trivial_ops(%a: index, %b: index) {
%7 = index.maxs %a, %b
// CHECK: llvm.intr.umax
%8 = index.maxu %a, %b
+ // CHECK: llvm.shl
+ %9 = index.shl %a, %b
+ // CHECK: llvm.ashr
+ %10 = index.shrs %a, %b
+ // CHECK: llvm.lshr
+ %11 = index.shru %a, %b
// CHECK: llvm.mlir.constant(true
- %9 = index.bool.constant true
+ %12 = index.bool.constant true
return
}
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index f9b33f88a1a26..288593f64c3f7 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -279,6 +279,111 @@ func.func @maxu() -> index {
return %0 : index
}
+// CHECK-LABEL: @shl
+func.func @shl() -> index {
+ %lhs = index.constant 128
+ %rhs = index.constant 2
+ // CHECK: %[[A:.*]] = index.constant 512
+ %0 = index.shl %lhs, %rhs
+ // CHECK: return %[[A]]
+ return %0 : index
+}
+
+// A shift amount of 32 is undefined for a 32-bit index, so it must not fold.
+// CHECK-LABEL: @shl_32
+func.func @shl_32() -> index {
+ %lhs = index.constant 1
+ %rhs = index.constant 32
+ // CHECK: index.shl
+ %0 = index.shl %lhs, %rhs
+ return %0 : index
+}
+
+// 31 is the largest shift amount that still folds; the 64-bit result is
+// 4000000000 * 2^31 = 8589934592000000000.
+// CHECK-LABEL: @shl_edge
+func.func @shl_edge() -> index {
+ %lhs = index.constant 4000000000
+ %rhs = index.constant 31
+ // CHECK: %[[A:.*]] = index.constant 858{{[0-9]+}}
+ %0 = index.shl %lhs, %rhs
+ // CHECK: return %[[A]]
+ return %0 : index
+}
+
+// CHECK-LABEL: @shrs
+func.func @shrs() -> index {
+ %lhs = index.constant 128
+ %rhs = index.constant 2
+ // CHECK: %[[A:.*]] = index.constant 32
+ %0 = index.shrs %lhs, %rhs
+ // CHECK: return %[[A]]
+ return %0 : index
+}
+
+// A shift amount of 32 is undefined for a 32-bit index, so it must not fold.
+// CHECK-LABEL: @shrs_32
+func.func @shrs_32() -> index {
+ %lhs = index.constant 4000000000000
+ %rhs = index.constant 32
+ // CHECK: index.shrs
+ %0 = index.shrs %lhs, %rhs
+ return %0 : index
+}
+
+// Bit 32 of the LHS shifts into bit 31: the 64-bit result (0x80000000) and
+// the 32-bit result (0) differ in their low 32 bits, so the fold is rejected.
+// CHECK-LABEL: @shrs_nofold
+func.func @shrs_nofold() -> index {
+ %lhs = index.constant 0x100000000
+ %rhs = index.constant 1
+ // CHECK: index.shrs
+ %0 = index.shrs %lhs, %rhs
+ return %0 : index
+}
+
+// 2^40 >> 3 = 2^37: the low 32 bits are zero for both the 64-bit and the
+// 32-bit computation, so the fold is allowed.
+// CHECK-LABEL: @shrs_edge
+func.func @shrs_edge() -> index {
+ %lhs = index.constant 0x10000000000
+ %rhs = index.constant 3
+ // CHECK: %[[A:.*]] = index.constant 137{{[0-9]+}}
+ %0 = index.shrs %lhs, %rhs
+ // CHECK: return %[[A]]
+ return %0 : index
+}
+
+// CHECK-LABEL: @shru
+func.func @shru() -> index {
+ %lhs = index.constant 128
+ %rhs = index.constant 2
+ // CHECK: %[[A:.*]] = index.constant 32
+ %0 = index.shru %lhs, %rhs
+ // CHECK: return %[[A]]
+ return %0 : index
+}
+
+// A shift amount of 32 is undefined for a 32-bit index, so it must not fold.
+// CHECK-LABEL: @shru_32
+func.func @shru_32() -> index {
+ %lhs = index.constant 4000000000000
+ %rhs = index.constant 32
+ // CHECK: index.shru
+ %0 = index.shru %lhs, %rhs
+ return %0 : index
+}
+
+// Bit 32 of the LHS shifts into bit 31: the 64-bit result (0x80000000) and
+// the 32-bit result (0) differ in their low 32 bits, so the fold is rejected.
+// CHECK-LABEL: @shru_nofold
+func.func @shru_nofold() -> index {
+ %lhs = index.constant 0x100000000
+ %rhs = index.constant 1
+ // CHECK: index.shru
+ %0 = index.shru %lhs, %rhs
+ return %0 : index
+}
+
+// 2^40 >> 3 = 2^37: the low 32 bits are zero for both the 64-bit and the
+// 32-bit computation, so the fold is allowed.
+// CHECK-LABEL: @shru_edge
+func.func @shru_edge() -> index {
+ %lhs = index.constant 0x10000000000
+ %rhs = index.constant 3
+ // CHECK: %[[A:.*]] = index.constant 137{{[0-9]+}}
+ %0 = index.shru %lhs, %rhs
+ // CHECK: return %[[A]]
+ return %0 : index
+}
+
// CHECK-LABEL: @cmp
func.func @cmp() -> (i1, i1, i1, i1) {
%a = index.constant 0
diff --git a/mlir/test/Dialect/Index/index-ops.mlir b/mlir/test/Dialect/Index/index-ops.mlir
index 2176efe337309..d1a409780cd51 100644
--- a/mlir/test/Dialect/Index/index-ops.mlir
+++ b/mlir/test/Dialect/Index/index-ops.mlir
@@ -27,6 +27,12 @@ func.func @binary_ops(%a: index, %b: index) {
%10 = index.maxs %a, %b
// CHECK-NEXT: index.maxu %[[A]], %[[B]]
%11 = index.maxu %a, %b
+ // CHECK-NEXT: index.shl %[[A]], %[[B]]
+ %12 = index.shl %a, %b
+ // CHECK-NEXT: index.shrs %[[A]], %[[B]]
+ %13 = index.shrs %a, %b
+ // CHECK-NEXT: index.shru %[[A]], %[[B]]
+ %14 = index.shru %a, %b
return
}
More information about the Mlir-commits
mailing list