[Mlir-commits] [mlir] 3597727 - [MLIR] Support lowering n-D arith.index_cast to LLVM

Hanhan Wang llvmlistbot at llvm.org
Tue Aug 9 11:18:06 PDT 2022


Author: Jerry Wu
Date: 2022-08-09T11:17:52-07:00
New Revision: 3597727fa7ff104bb0f6e39a71900960b87144df

URL: https://github.com/llvm/llvm-project/commit/3597727fa7ff104bb0f6e39a71900960b87144df
DIFF: https://github.com/llvm/llvm-project/commit/3597727fa7ff104bb0f6e39a71900960b87144df.diff

LOG: [MLIR] Support lowering n-D arith.index_cast to LLVM

Previously, arith.index_cast could only be lowered to LLVM for scalars and 1-D vectors. This change adds support for n-D vectors.

Reviewed By: ftynse, hanchung

Differential Revision: https://reviews.llvm.org/D129907

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp
    mlir/test/Conversion/ArithmeticToLLVM/convert-nd-vector-to-llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp b/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp
index 93090f7bb9bd5..5d378342b11a2 100644
--- a/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp
@@ -131,22 +131,50 @@ ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
 LogicalResult IndexCastOpLowering::matchAndRewrite(
     arith::IndexCastOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
-  auto targetType = typeConverter->convertType(op.getResult().getType());
+  auto resultType = op.getResult().getType();
+  // Compare bit widths of the LLVM-converted element types. The source width
+  // is taken from the original operand type: for n-D vectors the converted
+  // operand is an !llvm.array, which getElementTypeOrSelf does not look into.
   auto targetElementType =
-      typeConverter->convertType(getElementTypeOrSelf(op.getResult()))
-          .cast<IntegerType>();
+      typeConverter->convertType(getElementTypeOrSelf(resultType));
   auto sourceElementType =
-      getElementTypeOrSelf(adaptor.getIn()).cast<IntegerType>();
-  unsigned targetBits = targetElementType.getWidth();
-  unsigned sourceBits = sourceElementType.getWidth();
+      typeConverter->convertType(getElementTypeOrSelf(op.getIn()));
+  unsigned targetBits = targetElementType.getIntOrFloatBitWidth();
+  unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth();
 
-  if (targetBits == sourceBits)
+  // Equal widths after conversion (e.g. i64 <-> index lowered to i64): the
+  // cast is a no-op, so forward the converted operand unchanged.
+  if (targetBits == sourceBits) {
     rewriter.replaceOp(op, adaptor.getIn());
-  else if (targetBits < sourceBits)
-    rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, adaptor.getIn());
-  else
-    rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType, adaptor.getIn());
-  return success();
+    return success();
+  }
+
+  // Emits the width-changing cast of `value` to `castTy`: trunc when
+  // narrowing, sext when widening (index_cast is sign-extending).
+  auto createCast = [&](Type castTy, Value value) -> Value {
+    if (targetBits < sourceBits)
+      return rewriter.create<LLVM::TruncOp>(op.getLoc(), castTy, value);
+    return rewriter.create<LLVM::SExtOp>(op.getLoc(), castTy, value);
+  };
+
+  // Scalars and 1-D vectors convert to a single LLVM value: cast it directly.
+  if (!adaptor.getIn().getType().isa<LLVM::LLVMArrayType>()) {
+    auto targetType = typeConverter->convertType(resultType);
+    rewriter.replaceOp(op, createCast(targetType, adaptor.getIn()));
+    return success();
+  }
+
+  // n-D vectors convert to nested LLVM arrays of 1-D vectors: cast each
+  // innermost 1-D vector and rebuild the array from the results.
+  if (!resultType.isa<VectorType>())
+    return rewriter.notifyMatchFailure(op, "expected vector result type");
+
+  return LLVM::detail::handleMultidimensionalVectors(
+      op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
+      [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
+        return createCast(llvm1DVectorTy, OpAdaptor(operands).getIn());
+      },
+      rewriter);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/ArithmeticToLLVM/convert-nd-vector-to-llvmir.mlir b/mlir/test/Conversion/ArithmeticToLLVM/convert-nd-vector-to-llvmir.mlir
index 0df07b5d80532..7c219578baedd 100644
--- a/mlir/test/Conversion/ArithmeticToLLVM/convert-nd-vector-to-llvmir.mlir
+++ b/mlir/test/Conversion/ArithmeticToLLVM/convert-nd-vector-to-llvmir.mlir
@@ -210,3 +210,25 @@ func.func @select_2d(%arg0 : vector<4x3xi1>, %arg1 : vector<4x3xi32>, %arg2 : ve
   %0 = arith.select %arg0, %arg1, %arg2 : vector<4x3xi1>, vector<4x3xi32>
   func.return
 }
+
+// n-D index_cast is unrolled into one sext/trunc per innermost 1-D vector.
+// CHECK-LABEL: func @index_cast_3d(
+// CHECK-SAME:    %[[ARG0:.*]]: vector<1x2x3xi1>)
+func.func @index_cast_3d(%arg0: vector<1x2x3xi1>) {
+  // CHECK: %[[SRC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+  // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %[[SRC]][0, 0] : !llvm.array<1 x array<2 x vector<3xi1>>>
+  // CHECK: %[[SEXT1:.*]] = llvm.sext %[[EXTRACT1]] : vector<3xi1> to vector<3xi{{.*}}>
+  // CHECK: %[[INSERT1:.*]] = llvm.insertvalue %[[SEXT1]], %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi{{.*}}>>>
+  // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %[[SRC]][0, 1] : !llvm.array<1 x array<2 x vector<3xi1>>>
+  // CHECK: %[[SEXT2:.*]] = llvm.sext %[[EXTRACT2]] : vector<3xi1> to vector<3xi{{.*}}>
+  // CHECK: %[[INSERT2:.*]] = llvm.insertvalue %[[SEXT2]], %[[INSERT1]][0, 1] : !llvm.array<1 x array<2 x vector<3xi{{.*}}>>>
+  %0 = arith.index_cast %arg0: vector<1x2x3xi1> to vector<1x2x3xindex>
+  // CHECK: %[[EXTRACT3:.*]] = llvm.extractvalue %[[INSERT2]][0, 0] : !llvm.array<1 x array<2 x vector<3xi{{.*}}>>>
+  // CHECK: %[[TRUNC1:.*]] = llvm.trunc %[[EXTRACT3]] : vector<3xi{{.*}}> to vector<3xi1>
+  // CHECK: %[[INSERT3:.*]] = llvm.insertvalue %[[TRUNC1]], %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi1>>>
+  // CHECK: %[[EXTRACT4:.*]] = llvm.extractvalue %[[INSERT2]][0, 1] : !llvm.array<1 x array<2 x vector<3xi{{.*}}>>>
+  // CHECK: %[[TRUNC2:.*]] = llvm.trunc %[[EXTRACT4]] : vector<3xi{{.*}}> to vector<3xi1>
+  // CHECK: %[[INSERT4:.*]] = llvm.insertvalue %[[TRUNC2]], %[[INSERT3]][0, 1] : !llvm.array<1 x array<2 x vector<3xi1>>>
+  %1 = arith.index_cast %0: vector<1x2x3xindex> to vector<1x2x3xi1>
+  func.return
+}


        


More information about the Mlir-commits mailing list