[Mlir-commits] [mlir] 36e68c1 - [mlir][spirv] Add support for converting vector.shuffle

Lei Zhang llvmlistbot at llvm.org
Fri Feb 4 13:31:34 PST 2022


Author: Lei Zhang
Date: 2022-02-04T16:31:11-05:00
New Revision: 36e68c11ad6f20d058dfad001470bede8d9731dd

URL: https://github.com/llvm/llvm-project/commit/36e68c11ad6f20d058dfad001470bede8d9731dd
DIFF: https://github.com/llvm/llvm-project/commit/36e68c11ad6f20d058dfad001470bede8d9731dd.diff

LOG: [mlir][spirv] Add support for converting vector.shuffle

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D119030

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/simple.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index b40052df603b9..4061f57909cac 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -18,7 +18,9 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
 #include <numeric>
 
 using namespace mlir;
@@ -264,6 +266,43 @@ class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
   }
 };
 
+struct VectorShuffleOpConvert final
+    : public OpConversionPattern<vector::ShuffleOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto oldResultType = shuffleOp.getVectorType();
+    if (!spirv::CompositeType::isValid(oldResultType))
+      return failure();
+    auto newResultType = getTypeConverter()->convertType(oldResultType);
+
+    auto oldSourceType = shuffleOp.getV1VectorType();
+    if (oldSourceType.getNumElements() > 1) {
+      SmallVector<int32_t, 4> components = llvm::to_vector<4>(
+          llvm::map_range(shuffleOp.mask(), [](Attribute attr) -> int32_t {
+            return attr.cast<IntegerAttr>().getValue().getZExtValue();
+          }));
+      rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
+          shuffleOp, newResultType, adaptor.v1(), adaptor.v2(),
+          rewriter.getI32ArrayAttr(components));
+      return success();
+    }
+
+    SmallVector<Value, 2> oldOperands = {adaptor.v1(), adaptor.v2()};
+    SmallVector<Value, 4> newOperands;
+    newOperands.reserve(oldResultType.getNumElements());
+    for (const APInt &i : shuffleOp.mask().getAsValueRange<IntegerAttr>()) {
+      newOperands.push_back(oldOperands[i.getZExtValue()]);
+    }
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
+        shuffleOp, newResultType, newOperands);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
@@ -272,6 +311,6 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                VectorExtractElementOpConvert, VectorExtractOpConvert,
                VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
                VectorInsertElementOpConvert, VectorInsertOpConvert,
-               VectorInsertStridedSliceOpConvert, VectorSplatPattern>(
-      typeConverter, patterns.getContext());
+               VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+               VectorSplatPattern>(typeConverter, patterns.getContext());
 }

diff  --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
index d93aa21f7dc36..949e59951a4f4 100644
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -179,3 +179,34 @@ func @splat(%f : f32) -> vector<4xf32> {
   %splat = vector.splat %f : vector<4xf32>
   return %splat : vector<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL:  func @shuffle
+//  CHECK-SAME:  %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>
+//       CHECK:    %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+//       CHECK:    %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]]
+//       CHECK:    spv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : vector<4xf32>
+func @shuffle(%v0 : vector<1xf32>, %v1: vector<1xf32>) -> vector<4xf32> {
+  %shuffle = vector.shuffle %v0, %v1 [0, 1, 1, 0] : vector<1xf32>, vector<1xf32>
+  return %shuffle : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL:  func @shuffle
+//  CHECK-SAME:  %[[V0:.+]]: vector<3xf32>, %[[V1:.+]]: vector<3xf32>
+//       CHECK:    spv.VectorShuffle [3 : i32, 2 : i32, 5 : i32, 1 : i32] %[[V0]] : vector<3xf32>, %[[V1]] : vector<3xf32> -> vector<4xf32>
+func @shuffle(%v0 : vector<3xf32>, %v1: vector<3xf32>) -> vector<4xf32> {
+  %shuffle = vector.shuffle %v0, %v1 [3, 2, 5, 1] : vector<3xf32>, vector<3xf32>
+  return %shuffle : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL:  func @shuffle
+func @shuffle(%v0 : vector<2x16xf32>, %v1: vector<1x16xf32>) -> vector<3x16xf32> {
+  // CHECK: vector.shuffle
+  %shuffle = vector.shuffle %v0, %v1 [0, 1, 2] : vector<2x16xf32>, vector<1x16xf32>
+  return %shuffle : vector<3x16xf32>
+}


        


More information about the Mlir-commits mailing list