[Mlir-commits] [mlir] f8fafe9 - [mlir] Add unsigned version of index_cast
Thomas Raoux
llvmlistbot at llvm.org
Mon Oct 3 11:51:34 PDT 2022
Author: Thomas Raoux
Date: 2022-10-03T18:51:15Z
New Revision: f8fafe99a4ee2c047acf5a79d1033da8024f1f26
URL: https://github.com/llvm/llvm-project/commit/f8fafe99a4ee2c047acf5a79d1033da8024f1f26
DIFF: https://github.com/llvm/llvm-project/commit/f8fafe99a4ee2c047acf5a79d1033da8024f1f26.diff
LOG: [mlir] Add unsigned version of index_cast
This is required to cast an integer type to a potentially larger index type using a zero-extending cast.
There is a larger change under discussion to move index ops into a separate dialect: https://discourse.llvm.org/t/rfc-index-dialect/65540/
Depending on the timing of that work, this patch could be included as part of that effort, but as a short-term solution we may want to add this op to the arithmetic dialect now to fill the gap.
Reviewed By: Mogball, stellaraccident
Differential Revision: https://reviews.llvm.org/D135089
Added:
Modified:
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
mlir/test/Dialect/Arith/canonicalize.mlir
mlir/test/Dialect/Arith/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index c143f9a813bdf..b59d9fd03c683 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1036,6 +1036,26 @@ def Arith_IndexCastOp
let hasCanonicalizer = 1;
}
+//===----------------------------------------------------------------------===//
+// IndexCastUIOp
+//===----------------------------------------------------------------------===//
+
+def Arith_IndexCastUIOp
+ : Arith_CastOp<"index_castui", IndexCastTypeConstraint, IndexCastTypeConstraint,
+ [DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
+ let summary = "unsigned cast between index and integer types";
+ let description = [{
+ Casts between scalar or vector integers and corresponding 'index' scalars or
+ vectors. Index is an integer of platform-specific bit width. If casting to
+ a wider integer, the value is zero-extended. If casting to a narrower
+ integer, the value is truncated. This is the unsigned counterpart of
+ `arith.index_cast`, which sign-extends instead.
+ }];
+
+ let hasFolder = 1;
+ let hasCanonicalizer = 1;
+}
+
//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index c77edd8d242a0..1610e5cee8b7d 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -104,14 +104,21 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
/// becomes an integer. If the bit width of the source and target integer
/// types is the same, just erase the cast. If the target type is wider,
-/// sign-extend the value, otherwise truncate it.
-struct IndexCastOpLowering : public ConvertOpToLLVMPattern<arith::IndexCastOp> {
- using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+/// extend the value with `ExtCastTy` (sign-extension for index_cast,
+/// zero-extension for index_castui), otherwise truncate it.
+template <typename OpTy, typename ExtCastTy>
+struct IndexCastOpLowering : public ConvertOpToLLVMPattern<OpTy> {
+ using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
+ matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
+using IndexCastOpSILowering =
+ IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
+using IndexCastOpUILowering =
+ IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
+
struct AddUICarryOpLowering
: public ConvertOpToLLVMPattern<arith::AddUICarryOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
@@ -155,14 +162,15 @@ ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
// IndexCastOpLowering
//===----------------------------------------------------------------------===//
-LogicalResult IndexCastOpLowering::matchAndRewrite(
- arith::IndexCastOp op, OpAdaptor adaptor,
+template <typename OpTy, typename ExtCastTy>
+LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
+ OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Type resultType = op.getResult().getType();
Type targetElementType =
- typeConverter->convertType(getElementTypeOrSelf(resultType));
+ this->typeConverter->convertType(getElementTypeOrSelf(resultType));
Type sourceElementType =
- typeConverter->convertType(getElementTypeOrSelf(op.getIn()));
+ this->typeConverter->convertType(getElementTypeOrSelf(op.getIn()));
unsigned targetBits = targetElementType.getIntOrFloatBitWidth();
unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth();
@@ -174,13 +182,12 @@ LogicalResult IndexCastOpLowering::matchAndRewrite(
// Handle the scalar and 1D vector cases.
Type operandType = adaptor.getIn().getType();
if (!operandType.isa<LLVM::LLVMArrayType>()) {
- Type targetType = typeConverter->convertType(resultType);
+ Type targetType = this->typeConverter->convertType(resultType);
if (targetBits < sourceBits)
rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
adaptor.getIn());
else
- rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType,
- adaptor.getIn());
+ rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType, adaptor.getIn());
return success();
}
@@ -188,15 +195,15 @@ LogicalResult IndexCastOpLowering::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "expected vector result type");
return LLVM::detail::handleMultidimensionalVectors(
- op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
+ op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
[&](Type llvm1DVectorTy, ValueRange operands) -> Value {
- OpAdaptor adaptor(operands);
+ typename OpTy::Adaptor adaptor(operands);
if (targetBits < sourceBits) {
return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy,
adaptor.getIn());
}
- return rewriter.create<LLVM::SExtOp>(op.getLoc(), llvm1DVectorTy,
- adaptor.getIn());
+ return rewriter.create<ExtCastTy>(op.getLoc(), llvm1DVectorTy,
+ adaptor.getIn());
},
rewriter);
}
@@ -366,7 +373,8 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
ExtUIOpLowering,
FPToSIOpLowering,
FPToUIOpLowering,
- IndexCastOpLowering,
+ IndexCastOpSILowering,
+ IndexCastOpUILowering,
MaxFOpLowering,
MaxSIOpLowering,
MaxUIOpLowering,
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 034abcc7eb3ea..24bead82c18bb 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -990,6 +990,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
+ TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>, // unsigned counterpart of the IndexCastOp -> SConvertOp pattern above
TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
CmpIOpBooleanPattern, CmpIOpPattern,
CmpFOpNanNonePattern, CmpFOpPattern,
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index f57c4d88e8518..2cb5a553634bb 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -151,6 +151,24 @@ def IndexCastOfIndexCast :
def IndexCastOfExtSI :
Pat<(Arith_IndexCastOp (Arith_ExtSIOp $x)), (Arith_IndexCastOp $x)>;
+//===----------------------------------------------------------------------===//
+// IndexCastUIOp
+//===----------------------------------------------------------------------===//
+
+// index_castui(index_castui(x)) -> x, if dstType == srcType.
+// NOTE(review): the constraint only compares the outer types, so this also
+// fires when the intermediate type is narrower (e.g. index -> i8 -> index),
+// where the inner cast truncates and the round trip is not the identity.
+// Confirm this is acceptable or add a width check on the intermediate type.
+def IndexCastUIOfIndexCastUI :
+ Pat<(Arith_IndexCastUIOp:$res (Arith_IndexCastUIOp $x)),
+ (replaceWithValue $x),
+ [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;
+
+// index_castui(extui(x)) -> index_castui(x)
+def IndexCastUIOfExtUI :
+ Pat<(Arith_IndexCastUIOp (Arith_ExtUIOp $x)), (Arith_IndexCastUIOp $x)>;
+
//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d62e3f34459e7..190a1ef72d5bd 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1261,8 +1261,7 @@ OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
// IndexCastOp
//===----------------------------------------------------------------------===//
-bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
- TypeRange outputs) {
+static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
if (!areValidCastInputsAndOutputs(inputs, outputs))
return false;
@@ -1275,6 +1274,11 @@ bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
(srcType.isSignlessInteger() && dstType.isIndex());
}
+bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
+ TypeRange outputs) {
+ return areIndexCastCompatible(inputs, outputs);
+}
+
OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
// index_cast(constant) -> constant
// A little hack because we go through int. Otherwise, the size of the
@@ -1290,6 +1294,37 @@ void arith::IndexCastOp::getCanonicalizationPatterns(
patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
}
+//===----------------------------------------------------------------------===//
+// IndexCastUIOp
+//===----------------------------------------------------------------------===//
+
+bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
+ TypeRange outputs) {
+ return areIndexCastCompatible(inputs, outputs);
+}
+
+OpFoldResult arith::IndexCastUIOp::fold(ArrayRef<Attribute> operands) {
+  // index_castui(constant) -> constant
+  // Zero-extend or truncate the constant's APInt to the result width directly.
+  // IntegerAttr::getUInt() asserts that the attribute has an unsigned integer
+  // type, but this op's operands are signless integers or index; going through
+  // a 64-bit integer would also drop the high bits of wider constants.
+  auto value = operands[0].dyn_cast_or_null<IntegerAttr>();
+  if (!value)
+    return {};
+
+  Type resultType = getType();
+  unsigned resultWidth = IndexType::kInternalStorageBitWidth;
+  if (!resultType.isIndex())
+    resultWidth = resultType.getIntOrFloatBitWidth();
+  return IntegerAttr::get(resultType, value.getValue().zextOrTrunc(resultWidth));
+}
+
+void arith::IndexCastUIOp::getCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
+}
+
//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index e59469cd5f65e..243c3ef50faad 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -466,13 +466,20 @@ void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
// ExtUIOp
//===----------------------------------------------------------------------===//
+/// Zero-extends the unsigned bounds of `range` to the storage bitwidth of
+/// `destType` and returns them as an unsigned range.
+static ConstantIntRanges extUIRange(const ConstantIntRanges &range,
+ Type destType) {
+ unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
+ APInt umin = range.umin().zext(destWidth);
+ APInt umax = range.umax().zext(destWidth);
+ return ConstantIntRanges::fromUnsigned(umin, umax);
+}
+
void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
Type destType = getResult().getType();
- unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
- APInt umin = argRanges[0].umin().zext(destWidth);
- APInt umax = argRanges[0].umax().zext(destWidth);
- setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
+ setResultRange(getResult(), extUIRange(argRanges[0], destType));
}
//===----------------------------------------------------------------------===//
@@ -559,6 +566,25 @@ void arith::IndexCastOp::inferResultRanges(
setResultRange(getResult(), argRanges[0]);
}
+//===----------------------------------------------------------------------===//
+// IndexCastUIOp
+//===----------------------------------------------------------------------===//
+
+void arith::IndexCastUIOp::inferResultRanges(
+ ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+ Type sourceType = getOperand().getType();
+ Type destType = getResult().getType();
+ unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
+ unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
+
+ if (srcWidth < destWidth)
+ setResultRange(getResult(), extUIRange(argRanges[0], destType));
+ else if (srcWidth > destWidth)
+ setResultRange(getResult(), truncIRange(argRanges[0], destType));
+ else
+ setResultRange(getResult(), argRanges[0]);
+}
+
//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index c476d43627275..05706d89de742 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -92,6 +92,24 @@ func.func @vector_index_cast(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
return
}
+// CHECK-LABEL: @index_castui
+func.func @index_castui(%arg0: index, %arg1: i1) {
+// CHECK: = llvm.trunc %0 : i{{.*}} to i1
+ %0 = arith.index_castui %arg0: index to i1
+// CHECK-NEXT: = llvm.zext %arg1 : i1 to i{{.*}}
+ %1 = arith.index_castui %arg1: i1 to index
+ return
+}
+
+// CHECK-LABEL: @vector_index_castui
+func.func @vector_index_castui(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
+// CHECK: = llvm.trunc %{{.*}} : vector<2xi{{.*}}> to vector<2xi1>
+ %0 = arith.index_castui %arg0: vector<2xindex> to vector<2xi1>
+// CHECK-NEXT: = llvm.zext %{{.*}} : vector<2xi1> to vector<2xi{{.*}}>
+ %1 = arith.index_castui %arg1: vector<2xi1> to vector<2xindex>
+ return
+}
+
// Checking conversion of signed integer types to floating point.
// CHECK-LABEL: @sitofp
func.func @sitofp(%arg0 : i32, %arg1 : i64) {
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 16a7967c3a340..bd5238f3629f6 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -690,7 +690,35 @@ func.func @index_cast3(%arg0: i32) {
// CHECK-LABEL: index_cast4
func.func @index_cast4(%arg0: index) {
// CHECK-NOT: spirv.SConvert
%0 = arith.index_cast %arg0 : index to i32
return
}
+
+// CHECK-LABEL: index_castui1
+func.func @index_castui1(%arg0: i16) {
+ // CHECK: spirv.UConvert %{{.+}} : i16 to i32
+ %0 = arith.index_castui %arg0 : i16 to index
+ return
+}
+
+// CHECK-LABEL: index_castui2
+func.func @index_castui2(%arg0: index) {
+ // CHECK: spirv.UConvert %{{.+}} : i32 to i16
+ %0 = arith.index_castui %arg0 : index to i16
+ return
+}
+
+// CHECK-LABEL: index_castui3
+func.func @index_castui3(%arg0: i32) {
+ // CHECK-NOT: spirv.UConvert
+ %0 = arith.index_castui %arg0 : i32 to index
+ return
+}
+
+// CHECK-LABEL: index_castui4
+func.func @index_castui4(%arg0: index) {
+ // CHECK-NOT: spirv.UConvert
+ %0 = arith.index_castui %arg0 : index to i32
+ return
+}
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 632e7af4a26a3..be680acea733c 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -308,6 +308,24 @@ func.func @indexCastOfSignExtend(%arg0: i8) -> index {
return %idx : index
}
+// CHECK-LABEL: @indexCastUIOfUnsignedExtend
+// CHECK: %[[res:.+]] = arith.index_castui %arg0 : i8 to index
+// CHECK: return %[[res]]
+func.func @indexCastUIOfUnsignedExtend(%arg0: i8) -> index {
+ %ext = arith.extui %arg0 : i8 to i16
+ %idx = arith.index_castui %ext : i16 to index
+ return %idx : index
+}
+
+// CHECK-LABEL: @indexCastUIOfIndexCastUI
+// CHECK-NEXT: return %arg0 : i32
+func.func @indexCastUIOfIndexCastUI(%arg0: i32) -> i32 {
+ %idx = arith.index_castui %arg0 : i32 to index
+ %res = arith.index_castui %idx : index to i32
+ return %res : i32
+}
+
// CHECK-LABEL: @signExtendConstant
// CHECK: %[[cres:.+]] = arith.constant -2 : i16
// CHECK: return %[[cres]]
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 56e17c798f831..c34850ff6e305 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -793,6 +793,54 @@ func.func @test_index_cast_scalable_vector1(%arg0 : vector<[8]xindex>) -> vector
return %0 : vector<[8]xi64>
}
+// CHECK-LABEL: test_index_castui0
+func.func @test_index_castui0(%arg0 : i32) -> index {
+ %0 = arith.index_castui %arg0 : i32 to index
+ return %0 : index
+}
+
+// CHECK-LABEL: test_index_castui_tensor0
+func.func @test_index_castui_tensor0(%arg0 : tensor<8x8xi32>) -> tensor<8x8xindex> {
+ %0 = arith.index_castui %arg0 : tensor<8x8xi32> to tensor<8x8xindex>
+ return %0 : tensor<8x8xindex>
+}
+
+// CHECK-LABEL: test_index_castui_vector0
+func.func @test_index_castui_vector0(%arg0 : vector<8xi32>) -> vector<8xindex> {
+ %0 = arith.index_castui %arg0 : vector<8xi32> to vector<8xindex>
+ return %0 : vector<8xindex>
+}
+
+// CHECK-LABEL: test_index_castui_scalable_vector0
+func.func @test_index_castui_scalable_vector0(%arg0 : vector<[8]xi32>) -> vector<[8]xindex> {
+ %0 = arith.index_castui %arg0 : vector<[8]xi32> to vector<[8]xindex>
+ return %0 : vector<[8]xindex>
+}
+
+// CHECK-LABEL: test_index_castui1
+func.func @test_index_castui1(%arg0 : index) -> i64 {
+ %0 = arith.index_castui %arg0 : index to i64
+ return %0 : i64
+}
+
+// CHECK-LABEL: test_index_castui_tensor1
+func.func @test_index_castui_tensor1(%arg0 : tensor<8x8xindex>) -> tensor<8x8xi64> {
+ %0 = arith.index_castui %arg0 : tensor<8x8xindex> to tensor<8x8xi64>
+ return %0 : tensor<8x8xi64>
+}
+
+// CHECK-LABEL: test_index_castui_vector1
+func.func @test_index_castui_vector1(%arg0 : vector<8xindex>) -> vector<8xi64> {
+ %0 = arith.index_castui %arg0 : vector<8xindex> to vector<8xi64>
+ return %0 : vector<8xi64>
+}
+
+// CHECK-LABEL: test_index_castui_scalable_vector1
+func.func @test_index_castui_scalable_vector1(%arg0 : vector<[8]xindex>) -> vector<[8]xi64> {
+ %0 = arith.index_castui %arg0 : vector<[8]xindex> to vector<[8]xi64>
+ return %0 : vector<[8]xi64>
+}
+
// CHECK-LABEL: test_bitcast0
func.func @test_bitcast0(%arg0 : i64) -> f64 {
%0 = arith.bitcast %arg0 : i64 to f64
More information about the Mlir-commits
mailing list