[Mlir-commits] [mlir] 3597727 - [MLIR] Support lowering n-D arith.index_cast to LLVM
Hanhan Wang
llvmlistbot at llvm.org
Tue Aug 9 11:18:06 PDT 2022
Author: Jerry Wu
Date: 2022-08-09T11:17:52-07:00
New Revision: 3597727fa7ff104bb0f6e39a71900960b87144df
URL: https://github.com/llvm/llvm-project/commit/3597727fa7ff104bb0f6e39a71900960b87144df
DIFF: https://github.com/llvm/llvm-project/commit/3597727fa7ff104bb0f6e39a71900960b87144df.diff
LOG: [MLIR] Support lowering n-D arith.index_cast to LLVM
Previously, we could only lower arith.index_cast with 1-D vectors to LLVM. This change adds support for n-D vectors.
Reviewed By: ftynse, hanchung
Differential Revision: https://reviews.llvm.org/D129907
Added:
Modified:
mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp
mlir/test/Conversion/ArithmeticToLLVM/convert-nd-vector-to-llvmir.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp b/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp
index 93090f7bb9bd5..5d378342b11a2 100644
--- a/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp
@@ -131,22 +131,48 @@ ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
LogicalResult IndexCastOpLowering::matchAndRewrite(
arith::IndexCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- auto targetType = typeConverter->convertType(op.getResult().getType());
+ auto resultType = op.getResult().getType(); // Lowers index_cast to a forwarded value, llvm.trunc, or llvm.sext, chosen by the converted element bit widths.
auto targetElementType =
- typeConverter->convertType(getElementTypeOrSelf(op.getResult()))
- .cast<IntegerType>();
+ typeConverter->convertType(getElementTypeOrSelf(resultType)); // Converted element type of the result (the type itself when the result is not shaped).
auto sourceElementType =
- getElementTypeOrSelf(adaptor.getIn()).cast<IntegerType>();
- unsigned targetBits = targetElementType.getWidth();
- unsigned sourceBits = sourceElementType.getWidth();
+ typeConverter->convertType(getElementTypeOrSelf(op.getIn())); // Derived from the original operand, presumably because the converted operand may be an !llvm.array (see the isa check below).
+ unsigned targetBits = targetElementType.getIntOrFloatBitWidth(); // NOTE(review): convertType results are not null-checked before this call; confirm conversion cannot fail here.
+ unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth();
- if (targetBits == sourceBits)
+ if (targetBits == sourceBits) { // Same width after conversion: forward the converted operand unchanged.
rewriter.replaceOp(op, adaptor.getIn());
- else if (targetBits < sourceBits)
- rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, adaptor.getIn());
- else
- rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType, adaptor.getIn());
- return success();
+ return success();
+ }
+
+ // Handle the scalar and 1D vector cases.
+ auto operandType = adaptor.getIn().getType();
+ if (!operandType.isa<LLVM::LLVMArrayType>()) { // Converted operand is not an !llvm.array, so a single cast op suffices.
+ auto targetType = typeConverter->convertType(resultType);
+ if (targetBits < sourceBits) // Narrowing: truncate.
+ rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
+ adaptor.getIn());
+ else // Widening: sign-extend.
+ rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType,
+ adaptor.getIn());
+ return success();
+ }
+
+ auto vectorType = resultType.dyn_cast<VectorType>(); // Used only as a guard: an !llvm.array operand is expected only for vector results.
+ if (!vectorType)
+ return rewriter.notifyMatchFailure(op, "expected vector result type");
+
+ return LLVM::detail::handleMultidimensionalVectors( // n-D case: apply the cast to each innermost 1-D vector of the nested array.
+ op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
+ [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
+ OpAdaptor adaptor(operands); // NOTE(review): shadows the outer 'adaptor'; this one wraps the per-1-D-vector operands.
+ if (targetBits < sourceBits) { // Same trunc/sext choice as the scalar/1-D path.
+ return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy,
+ adaptor.getIn());
+ }
+ return rewriter.create<LLVM::SExtOp>(op.getLoc(), llvm1DVectorTy,
+ adaptor.getIn());
+ },
+ rewriter);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ArithmeticToLLVM/convert-nd-vector-to-llvmir.mlir b/mlir/test/Conversion/ArithmeticToLLVM/convert-nd-vector-to-llvmir.mlir
index 0df07b5d80532..7c219578baedd 100644
--- a/mlir/test/Conversion/ArithmeticToLLVM/convert-nd-vector-to-llvmir.mlir
+++ b/mlir/test/Conversion/ArithmeticToLLVM/convert-nd-vector-to-llvmir.mlir
@@ -210,3 +210,24 @@ func.func @select_2d(%arg0 : vector<4x3xi1>, %arg1 : vector<4x3xi32>, %arg2 : ve
%0 = arith.select %arg0, %arg1, %arg2 : vector<4x3xi1>, vector<4x3xi32>
func.return
}
+
+// CHECK-LABEL: func @index_cast_2d(
+// CHECK-SAME: %[[ARG0:.*]]: vector<1x2x3xi1>)
+func.func @index_cast_2d(%arg0: vector<1x2x3xi1>) { // Round-trips i1 -> index -> i1 on a rank-3 (1x2x3) vector; NOTE(review): the "_2d" name undersells the rank.
+ // CHECK: %[[SRC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+ // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %[[SRC]][0, 0] : !llvm.array<1 x array<2 x vector<3xi1>>>
+ // CHECK: %[[SEXT1:.*]] = llvm.sext %[[EXTRACT1]] : vector<3xi1> to vector<3xi{{.*}}>
+ // CHECK: %[[INSERT1:.*]] = llvm.insertvalue %[[SEXT1]], %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi{{.*}}>>>
+ // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %[[SRC]][0, 1] : !llvm.array<1 x array<2 x vector<3xi1>>>
+ // CHECK: %[[SEXT2:.*]] = llvm.sext %[[EXTRACT2]] : vector<3xi1> to vector<3xi{{.*}}>
+ // CHECK: %[[INSERT2:.*]] = llvm.insertvalue %[[SEXT2]], %[[INSERT1]][0, 1] : !llvm.array<1 x array<2 x vector<3xi{{.*}}>>>
+ %0 = arith.index_cast %arg0: vector<1x2x3xi1> to vector<1x2x3xindex> // Widening: expected to lower to one llvm.sext per inner 1-D vector.
+ // CHECK: %[[EXTRACT3:.*]] = llvm.extractvalue %[[INSERT2]][0, 0] : !llvm.array<1 x array<2 x vector<3xi{{.*}}>>>
+ // CHECK: %[[TRUNC1:.*]] = llvm.trunc %[[EXTRACT3]] : vector<3xi{{.*}}> to vector<3xi1>
+ // CHECK: %[[INSERT3:.*]] = llvm.insertvalue %[[TRUNC1]], %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi1>>>
+ // CHECK: %[[EXTRACT4:.*]] = llvm.extractvalue %[[INSERT2]][0, 1] : !llvm.array<1 x array<2 x vector<3xi{{.*}}>>>
+ // CHECK: %[[TRUNC2:.*]] = llvm.trunc %[[EXTRACT4]] : vector<3xi{{.*}}> to vector<3xi1>
+ // CHECK: %[[INSERT4:.*]] = llvm.insertvalue %[[TRUNC2]], %[[INSERT3]][0, 1] : !llvm.array<1 x array<2 x vector<3xi1>>>
+ %1 = arith.index_cast %0: vector<1x2x3xindex> to vector<1x2x3xi1> // Narrowing: expected to lower to one llvm.trunc per inner 1-D vector.
+ return
+}
More information about the Mlir-commits mailing list