[Mlir-commits] [mlir] 3483fc5 - [mlir][SPIRVToLLVM] Add shufflevector conversion
Weiwei Li
llvmlistbot at llvm.org
Mon Nov 1 08:05:58 PDT 2021
Author: Weiwei Li
Date: 2021-11-01T23:05:37+08:00
New Revision: 3483fc5a31179123ffd11494488074683706a492
URL: https://github.com/llvm/llvm-project/commit/3483fc5a31179123ffd11494488074683706a492
DIFF: https://github.com/llvm/llvm-project/commit/3483fc5a31179123ffd11494488074683706a492.diff
LOG: [mlir][SPIRVToLLVM] Add shufflevector conversion
Add the shufflevector conversion. It only handles static (i.e., IntegerAttr) component indices.
Co-authored: Xinyi Liu <xyliuhelen at gmail.com>
Reviewed by: antiagainst
Differential revision: https://reviews.llvm.org/D112161
Added:
Modified:
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 51df6c763c3fe..b416c303ad51a 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1364,6 +1364,65 @@ class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
}
};
+//===----------------------------------------------------------------------===//
+// VectorShuffleOp conversion
+//===----------------------------------------------------------------------===//
+
+/// Converts `spv.VectorShuffle` to the LLVM dialect.
+///
+/// If both source vectors have the same number of elements, the op is
+/// rewritten directly to `llvm.shufflevector` with the same component list.
+/// Otherwise the result is built element by element: starting from an
+/// `llvm.mlir.undef` of the converted result type, each component index
+/// selects a lane of `vector1` (index < size(vector1)) or of `vector2`
+/// (index rebased by size(vector1)); that lane is extracted and inserted at
+/// the component's position. A component of -1 (written 0xffffffff in the
+/// SPIR-V op) is skipped, leaving the corresponding result lane undefined.
+///
+/// Only constant (IntegerAttr) components are supported; any other attribute
+/// kind makes the pattern emit an error and fail.
+class VectorShufflePattern
+    : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto components = adaptor.components();
+    auto vector1 = adaptor.vector1();
+    auto vector2 = adaptor.vector2();
+    int vector1Size = vector1.getType().cast<VectorType>().getNumElements();
+    int vector2Size = vector2.getType().cast<VectorType>().getNumElements();
+
+    // Validate every component before creating any IR, so a rejected op
+    // leaves nothing behind and the `cast<IntegerAttr>` below cannot assert.
+    auto componentsArray = components.getValue();
+    for (Attribute component : componentsArray)
+      if (!component.isa<IntegerAttr>())
+        return op.emitError("unable to support non-constant component");
+
+    // Equal-sized operands map one-to-one onto llvm.shufflevector.
+    if (vector1Size == vector2Size) {
+      rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, vector1, vector2,
+                                                         components);
+      return success();
+    }
+
+    // Operands of different sizes: assemble the result lane by lane.
+    auto dstType = typeConverter.convertType(op.getType());
+    if (!dstType)
+      return failure();
+    auto scalarType = dstType.cast<VectorType>().getElementType();
+    auto llvmI32Type = IntegerType::get(rewriter.getContext(), 32);
+    Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
+    for (unsigned i = 0, e = componentsArray.size(); i < e; ++i) {
+      int indexVal = componentsArray[i].cast<IntegerAttr>().getInt();
+      // -1 marks an undefined result lane; keep the undef value there.
+      if (indexVal == -1)
+        continue;
+
+      // Indices past the end of vector1 address lanes of vector2.
+      int offsetVal = 0;
+      Value baseVector = vector1;
+      if (indexVal >= vector1Size) {
+        offsetVal = vector1Size;
+        baseVector = vector2;
+      }
+
+      Value dstIndex = rewriter.create<LLVM::ConstantOp>(
+          loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
+      Value index = rewriter.create<LLVM::ConstantOp>(
+          loc, llvmI32Type,
+          rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
+
+      auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
+          loc, scalarType, baseVector, index);
+      targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
+                                                        extractOp, dstIndex);
+    }
+    rewriter.replaceOp(op, targetOp);
+    return success();
+  }
+};
} // namespace
//===----------------------------------------------------------------------===//
@@ -1489,6 +1548,7 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
CompositeExtractPattern, CompositeInsertPattern,
DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
+ VectorShufflePattern,
// Shift ops
ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
diff --git a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
index c8528d062f7cb..38d55b9a659b8 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
@@ -58,6 +58,32 @@ spv.func @select_vector(%arg0: vector<2xi1>, %arg1: vector<2xi32>) "None" {
spv.Return
}
+//===----------------------------------------------------------------------===//
+// spv.VectorShuffle
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @vector_shuffle_same_size
+spv.func @vector_shuffle_same_size(%vector1: vector<2xf32>, %vector2: vector<2xf32>) -> vector<3xf32> "None" {
+  // CHECK: %[[res:.*]] = llvm.shufflevector {{.*}} [0 : i32, 2 : i32, -1 : i32] : vector<2xf32>, vector<2xf32>
+  // CHECK-NEXT: llvm.return %[[res]] : vector<3xf32>
+  %0 = spv.VectorShuffle [0: i32, 2: i32, 0xffffffff: i32] %vector1: vector<2xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+  spv.ReturnValue %0: vector<3xf32>
+}
+
+// CHECK-LABEL: @vector_shuffle_different_size
+spv.func @vector_shuffle_different_size(%vector1: vector<3xf32>, %vector2: vector<2xf32>) -> vector<3xf32> "None" {
+  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef : vector<3xf32>
+  // CHECK-NEXT: %[[C0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK-NEXT: %[[C0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK-NEXT: %[[EXT0:.*]] = llvm.extractelement %arg0[%[[C0_1]] : i32] : vector<3xf32>
+  // CHECK-NEXT: %[[INSERT0:.*]] = llvm.insertelement %[[EXT0]], %[[UNDEF]][%[[C0_0]] : i32] : vector<3xf32>
+  // CHECK-NEXT: %[[C1_0:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NEXT: %[[C1_1:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NEXT: %[[EXT1:.*]] = llvm.extractelement {{.*}}[%[[C1_1]] : i32] : vector<2xf32>
+  // CHECK-NEXT: %[[RES:.*]] = llvm.insertelement %[[EXT1]], %[[INSERT0]][%[[C1_0]] : i32] : vector<3xf32>
+  // CHECK-NEXT: llvm.return %[[RES]] : vector<3xf32>
+  %0 = spv.VectorShuffle [0: i32, 4: i32, 0xffffffff: i32] %vector1: vector<3xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+  spv.ReturnValue %0: vector<3xf32>
+}
+
//===----------------------------------------------------------------------===//
// spv.EntryPoint and spv.ExecutionMode
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list