[Mlir-commits] [mlir] 92bcb8c - [mlir][arith] Add `index_cast` and `index_castui` support to WIE

Jakub Kuderski llvmlistbot at llvm.org
Thu Nov 17 11:05:43 PST 2022


Author: Jakub Kuderski
Date: 2022-11-17T14:04:17-05:00
New Revision: 92bcb8ccbb8b55458a6d96b3b16ff2abc138b88c

URL: https://github.com/llvm/llvm-project/commit/92bcb8ccbb8b55458a6d96b3b16ff2abc138b88c
DIFF: https://github.com/llvm/llvm-project/commit/92bcb8ccbb8b55458a6d96b3b16ff2abc138b88c.diff

LOG: [mlir][arith] Add `index_cast` and `index_castui` support to WIE (wide integer emulation)

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D138225

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
    mlir/test/Dialect/Arith/emulate-wide-int.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 50069ae2b855e..ad42001996ec4 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -598,6 +598,86 @@ struct ConvertMaxMin final : OpConversionPattern<SourceOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertIndexCast
+//===----------------------------------------------------------------------===//
+
+/// Returns true iff `type` is `index` or a vector whose element type is
+/// `index` (e.g., `vector<4xindex>`). Other shaped types, such as tensors, are
+/// not matched.
+static bool isIndexOrIndexVector(Type type) {
+  // Look through a vector to its element type; check anything else as-is.
+  Type elementTy = type;
+  if (auto vecTy = type.dyn_cast<VectorType>())
+    elementTy = vecTy.getElementType();
+  return elementTy.isa<IndexType>();
+}
+
+/// Converts `arith.index_cast`/`arith.index_castui` from an emulated wide
+/// integer (or a vector of them) to `index` (or a vector of `index`). Only the
+/// low half of the input is cast; the high half is discarded.
+/// NOTE(review): this truncation assumes `index` is no wider than the narrow
+/// target integer type -- confirm for targets with a wider `index`.
+template <typename CastOp>
+struct ConvertIndexCastIntToIndex final : OpConversionPattern<CastOp> {
+  using OpConversionPattern<CastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type resultType = op.getType();
+    if (!isIndexOrIndexVector(resultType))
+      return rewriter.notifyMatchFailure(op, "expected index result type");
+
+    Location loc = op.getLoc();
+    Type inType = op.getIn().getType();
+    // Bail out unless the converter maps the input to a vector type, as it
+    // does for emulated wide integers (e.g., `i64` -> `vector<2xi32>`).
+    auto newInTy = this->getTypeConverter()
+                       ->convertType(inType)
+                       .template dyn_cast_or_null<VectorType>();
+    if (!newInTy)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", inType));
+
+    // Discard the high half of the input, truncating the original value.
+    Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
+    // Drop the trailing unit dimension left over from the slice.
+    extracted = dropTrailingX1Dim(rewriter, loc, extracted);
+    rewriter.replaceOpWithNewOp<CastOp>(op, resultType, extracted);
+    return success();
+  }
+};
+
+/// Converts `arith.index_cast`/`arith.index_castui` from `index` (or a vector
+/// of `index`) to an emulated wide integer type. The cast is first emitted to
+/// the narrow integer type (`getMaxTargetIntBitWidth()` bits) and the result
+/// is then widened with `ExtensionOp` (`arith.extsi` for `index_cast`,
+/// `arith.extui` for `index_castui`, as registered by the populate function).
+/// NOTE(review): `index` values wider than the narrow type are truncated by
+/// the narrow cast -- confirm this matches the intended targets.
+template <typename CastOp, typename ExtensionOp>
+struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> {
+  using OpConversionPattern<CastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type inType = op.getIn().getType();
+    if (!isIndexOrIndexVector(inType))
+      return rewriter.notifyMatchFailure(op, "expected index operand type");
+
+    Location loc = op.getLoc();
+    auto *typeConverter =
+        this->template getTypeConverter<arith::WideIntEmulationConverter>();
+
+    // Bail out unless the converter maps the result to a vector type, as it
+    // does for emulated wide integers (e.g., `i64` -> `vector<2xi32>`).
+    Type resultType = op.getType();
+    auto newTy = typeConverter->convertType(resultType)
+                     .template dyn_cast_or_null<VectorType>();
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", resultType));
+
+    // Emit an index cast over the matching narrow type.
+    Type narrowTy =
+        rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth());
+    if (auto vecTy = resultType.dyn_cast<VectorType>())
+      narrowTy = VectorType::get(vecTy.getShape(), narrowTy);
+    Value underlyingVal =
+        rewriter.create<CastOp>(loc, narrowTy, adaptor.getIn());
+
+    // Sign- or zero-extend the result. Let the matching conversion pattern
+    // legalize the extension op.
+    rewriter.replaceOpWithNewOp<ExtensionOp>(op, resultType, underlyingVal);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertSelect
 //===----------------------------------------------------------------------===//
@@ -841,8 +921,7 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
     // Rewrite this as an bitwise or of `arith.shrui` and sign extension bits.
     // Perform as many ops over the narrow integer type as possible and let the
     // other emulation patterns convert the rest.
-    Value elemZero =
-        createScalarOrSplatConstant(rewriter, loc, narrowTy, 0);
+    Value elemZero = createScalarOrSplatConstant(rewriter, loc, narrowTy, 0);
     Value signBit = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
     signBit = dropTrailingX1Dim(rewriter, loc, signBit);
@@ -862,7 +941,8 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
         rewriter.create<arith::ShLIOp>(loc, allSign, numNonSignExtBits);
 
     // Use original arguments to create the right shift.
-    Value shrui = rewriter.create<arith::ShRUIOp>(loc, op.getLhs(), op.getRhs());
+    Value shrui =
+        rewriter.create<arith::ShRUIOp>(loc, op.getLhs(), op.getRhs());
     Value shrsi = rewriter.create<arith::OrIOp>(loc, shrui, signBits);
 
     // Handle shifting by zero. This is necessary when the `signBits` shift is
@@ -870,7 +950,8 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
     Value isNoop = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                   rhsElem0, elemZero);
     isNoop = dropTrailingX1Dim(rewriter, loc, isNoop);
-    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(), shrsi);
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(),
+                                                 shrsi);
 
     return success();
   }
@@ -1045,6 +1126,11 @@ void arith::populateArithWideIntEmulationPatterns(
       ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
       ConvertBitwiseBinary<arith::XOrIOp>,
       // Extension and truncation ops.
-      ConvertExtSI, ConvertExtUI, ConvertTruncI>(typeConverter,
-                                                 patterns.getContext());
+      ConvertExtSI, ConvertExtUI, ConvertTruncI,
+      // Cast ops.
+      ConvertIndexCastIntToIndex<arith::IndexCastOp>,
+      ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
+      ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
+      ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>>(
+      typeConverter, patterns.getContext());
 }

diff  --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
index bfef4162ef8c5..0f85e7a859386 100644
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -365,6 +365,102 @@ func.func @extui_vector(%a : vector<3xi16>) -> vector<3xi64> {
     return %r : vector<3xi64>
 }
 
+// `index_cast` from i64 to index: the i64 is emulated as vector<2xi32>; only
+// element 0 (the low half) is cast, and the high half is dropped.
+// CHECK-LABEL: func @index_cast_int_to_index_scalar
+// CHECK-SAME:    ([[ARG:%.+]]: vector<2xi32>) -> index
+// CHECK-NEXT:    [[EXT:%.+]]  = vector.extract [[ARG]][0] : vector<2xi32>
+// CHECK-NEXT:    [[RES:%.+]]  = arith.index_cast [[EXT]] : i32 to index
+// CHECK-NEXT:    return [[RES]] : index
+func.func @index_cast_int_to_index_scalar(%a : i64) -> index {
+    %r = arith.index_cast %a : i64 to index
+    return %r : index
+}
+
+// Vector variant: the low halves are sliced out of the trailing dimension and
+// reshaped to vector<3xi32> before the cast.
+// CHECK-LABEL: func @index_cast_int_to_index_vector
+// CHECK-SAME:    ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xindex>
+// CHECK-NEXT:    [[EXT:%.+]]   = vector.extract_strided_slice [[ARG]] {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
+// CHECK-NEXT:    [[SHAPE:%.+]] = vector.shape_cast [[EXT]] : vector<3x1xi32> to vector<3xi32>
+// CHECK-NEXT:    [[RES:%.+]]   = arith.index_cast [[SHAPE]] : vector<3xi32> to vector<3xindex>
+// CHECK-NEXT:    return [[RES]] : vector<3xindex>
+func.func @index_cast_int_to_index_vector(%a : vector<3xi64>) -> vector<3xindex> {
+    %r = arith.index_cast %a : vector<3xi64> to vector<3xindex>
+    return %r : vector<3xindex>
+}
+
+// `index_castui` from i64 to index: same lowering as `index_cast`, keeping
+// only the low i32 half of the input.
+// CHECK-LABEL: func @index_castui_int_to_index_scalar
+// CHECK-SAME:    ([[ARG:%.+]]: vector<2xi32>) -> index
+// CHECK-NEXT:    [[EXT:%.+]]  = vector.extract [[ARG]][0] : vector<2xi32>
+// CHECK-NEXT:    [[RES:%.+]]  = arith.index_castui [[EXT]] : i32 to index
+// CHECK-NEXT:    return [[RES]] : index
+func.func @index_castui_int_to_index_scalar(%a : i64) -> index {
+    %r = arith.index_castui %a : i64 to index
+    return %r : index
+}
+
+// Vector variant of the unsigned cast: slice out the low halves, reshape, then
+// cast.
+// CHECK-LABEL: func @index_castui_int_to_index_vector
+// CHECK-SAME:    ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xindex>
+// CHECK-NEXT:    [[EXT:%.+]]   = vector.extract_strided_slice [[ARG]] {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
+// CHECK-NEXT:    [[SHAPE:%.+]] = vector.shape_cast [[EXT]] : vector<3x1xi32> to vector<3xi32>
+// CHECK-NEXT:    [[RES:%.+]]   = arith.index_castui [[SHAPE]] : vector<3xi32> to vector<3xindex>
+// CHECK-NEXT:    return [[RES]] : vector<3xindex>
+func.func @index_castui_int_to_index_vector(%a : vector<3xi64>) -> vector<3xindex> {
+    %r = arith.index_castui %a : vector<3xi64> to vector<3xindex>
+    return %r : vector<3xindex>
+}
+
+// `index_cast` from index to i64: cast to i32 first, then sign-extend. The
+// emulated extsi derives the high half from the sign of the low half.
+// CHECK-LABEL: func @index_cast_index_to_int_scalar
+// CHECK-SAME:    ([[ARG:%.+]]: index) -> vector<2xi32>
+// CHECK-NEXT:    [[CAST:%.+]]  = arith.index_cast [[ARG]] : index to i32
+// CHECK-NEXT:    [[C0I32:%.+]] = arith.constant 0 : i32
+// CHECK-NEXT:    [[NEG:%.+]]   = arith.cmpi slt, [[CAST]], [[C0I32]] : i32
+// CHECK-NEXT:    [[EXT:%.+]]   = arith.extsi [[NEG]] : i1 to i32
+// CHECK-NEXT:    [[VZ:%.+]]    = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT:    [[INS0:%.+]]  = vector.insert [[CAST]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK-NEXT:    [[INS1:%.+]]  = vector.insert [[EXT]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT:    return [[INS1]] : vector<2xi32>
+func.func @index_cast_index_to_int_scalar(%a : index) -> i64 {
+    %r = arith.index_cast %a : index to i64
+    return %r : i64
+}
+
+// Vector variant of the signed cast. Only the op sequence is matched here; the
+// scalar test above pins the operands.
+// CHECK-LABEL: func @index_cast_index_to_int_vector
+// CHECK-SAME:    ([[ARG:%.+]]: vector<3xindex>) -> vector<3x2xi32>
+// CHECK-NEXT:    arith.index_cast [[ARG]] : vector<3xindex> to vector<3xi32>
+// CHECK-NEXT:    vector.shape_cast
+// CHECK-NEXT:    arith.constant dense<0> : vector<3x1xi32>
+// CHECK-NEXT:    arith.cmpi slt
+// CHECK-NEXT:    arith.extsi
+// CHECK-NEXT:    arith.constant dense<0> : vector<3x2xi32>
+// CHECK-NEXT:    vector.insert_strided_slice
+// CHECK-NEXT:    vector.insert_strided_slice
+// CHECK-NEXT:    return {{%.+}} : vector<3x2xi32>
+func.func @index_cast_index_to_int_vector(%a : vector<3xindex>) -> vector<3xi64> {
+    %r = arith.index_cast %a : vector<3xindex> to vector<3xi64>
+    return %r : vector<3xi64>
+}
+
+// `index_castui` from index to i64: cast to i32, then zero-extend. The high
+// half stays zero from the dense<0> initial vector.
+// CHECK-LABEL: func @index_castui_index_to_int_scalar
+// CHECK-SAME:    ([[ARG:%.+]]: index) -> vector<2xi32>
+// CHECK-NEXT:    [[CAST:%.+]]  = arith.index_castui [[ARG]] : index to i32
+// CHECK-NEXT:    [[VZ:%.+]]    = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT:    [[RES:%.+]]   = vector.insert [[CAST]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK-NEXT:    return [[RES]] : vector<2xi32>
+func.func @index_castui_index_to_int_scalar(%a : index) -> i64 {
+    %r = arith.index_castui %a : index to i64
+    return %r : i64
+}
+
+// Vector variant of the unsigned cast: the narrow results are inserted as the
+// low halves of a zero-initialized vector<3x2xi32>.
+// CHECK-LABEL: func @index_castui_index_to_int_vector
+// CHECK-SAME:    ([[ARG:%.+]]: vector<3xindex>) -> vector<3x2xi32>
+// CHECK-NEXT:    [[CAST:%.+]]  = arith.index_castui [[ARG]] : vector<3xindex> to vector<3xi32>
+// CHECK-NEXT:    [[SHAPE:%.+]] = vector.shape_cast [[CAST]] : vector<3xi32> to vector<3x1xi32>
+// CHECK-NEXT:    [[CST:%.+]]   = arith.constant dense<0> : vector<3x2xi32>
+// CHECK-NEXT:    [[RES:%.+]]   = vector.insert_strided_slice [[SHAPE]], [[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<3x1xi32> into vector<3x2xi32>
+// CHECK-NEXT:    return [[RES]] : vector<3x2xi32>
+func.func @index_castui_index_to_int_vector(%a : vector<3xindex>) -> vector<3xi64> {
+    %r = arith.index_castui %a : vector<3xindex> to vector<3xi64>
+    return %r : vector<3xi64>
+}
+
 // CHECK-LABEL: func @trunci_scalar1
 // CHECK-SAME:    ([[ARG:%.+]]: vector<2xi32>) -> i32
 // CHECK-NEXT:    [[EXT:%.+]] = vector.extract [[ARG]][0] : vector<2xi32>


        


More information about the Mlir-commits mailing list