[Mlir-commits] [mlir] dd1992e - Support lowering of index-cast on vector types.

Jacques Pienaar llvmlistbot at llvm.org
Tue Jun 15 12:51:43 PDT 2021


Author: Arpith C. Jacob
Date: 2021-06-15T12:51:30-07:00
New Revision: dd1992efd3f1ebbaddc77edafcf17b967cafc1d9

URL: https://github.com/llvm/llvm-project/commit/dd1992efd3f1ebbaddc77edafcf17b967cafc1d9
DIFF: https://github.com/llvm/llvm-project/commit/dd1992efd3f1ebbaddc77edafcf17b967cafc1d9.diff

LOG: Support lowering of index-cast on vector types.

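The `index_cast` operation accepts vector types, but its lowering to the LLVM dialect only handled scalar integers. This patch extends the lowering to vectors by comparing element bit widths.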

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D104280

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index cfcda1f24214..7216e3d2ed5c 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1204,9 +1204,10 @@ def FPTruncOp : ArithmeticCastOp<"fptrunc">, Arguments<(ins AnyType:$in)> {
 def IndexCastOp : ArithmeticCastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
   let summary = "cast between index and integer types";
   let description = [{
-    Casts between integer scalars and 'index' scalars. Index is an integer of
-    platform-specific bit width. If casting to a wider integer, the value is
-    sign-extended. If casting to a narrower integer, the value is truncated.
+    Casts between scalar or vector integers and the corresponding 'index' scalars
+    or vectors. Index is an integer of platform-specific bit width. If casting to
+    a wider integer, the value is sign-extended. If casting to a narrower
+    integer, the value is truncated.
   }];
 
   let hasFolder = 1;

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 61074382470e..8e808d75e205 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -3104,11 +3104,15 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
     IndexCastOpAdaptor transformed(operands);
 
     auto targetType =
-        typeConverter->convertType(indexCastOp.getResult().getType())
+        typeConverter->convertType(indexCastOp.getResult().getType());
+    auto targetElementType =
+        typeConverter
+            ->convertType(getElementTypeOrSelf(indexCastOp.getResult()))
             .cast<IntegerType>();
-    auto sourceType = transformed.in().getType().cast<IntegerType>();
-    unsigned targetBits = targetType.getWidth();
-    unsigned sourceBits = sourceType.getWidth();
+    auto sourceElementType =
+        getElementTypeOrSelf(transformed.in()).cast<IntegerType>();
+    unsigned targetBits = targetElementType.getWidth();
+    unsigned sourceBits = sourceElementType.getWidth();
 
     if (targetBits == sourceBits)
       rewriter.replaceOp(indexCastOp, transformed.in());

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index b4942de204ab..78a03372904d 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -520,6 +520,15 @@ func @index_cast(%arg0: index, %arg1: i1) {
   return
 }
 
+// CHECK-LABEL: @vector_index_cast
+func @vector_index_cast(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
+// CHECK-NEXT: = llvm.trunc %{{.*}} : vector<2xi{{.*}}> to vector<2xi1>
+  %0 = index_cast %arg0: vector<2xindex> to vector<2xi1>
+// CHECK-NEXT: = llvm.sext %{{.*}} : vector<2xi1> to vector<2xi{{.*}}>
+  %1 = index_cast %arg1: vector<2xi1> to vector<2xindex>
+  return
+}
+
 // Checking conversion of signed integer types to floating point.
 // CHECK-LABEL: @sitofp
 func @sitofp(%arg0 : i32, %arg1 : i64) {


        


More information about the Mlir-commits mailing list