[Mlir-commits] [mlir] fb2c4d5 - [mlir] [VectorOps] Implement vector.constant_mask lowering to LLVM IR

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 12 19:44:38 PDT 2020


Author: aartbik
Date: 2020-05-12T19:44:23-07:00
New Revision: fb2c4d50f1426356d0aa3a9eaa1dab50b25c9082

URL: https://github.com/llvm/llvm-project/commit/fb2c4d50f1426356d0aa3a9eaa1dab50b25c9082
DIFF: https://github.com/llvm/llvm-project/commit/fb2c4d50f1426356d0aa3a9eaa1dab50b25c9082.diff

LOG: [mlir] [VectorOps] Implement vector.constant_mask lowering to LLVM IR

Summary:
Makes this operation runnable on CPU by lowering it to MLIR operations
(vector.insert of the lower-rank mask into an all-false vector) that are
eventually folded into an LLVM IR constant for the mask.

Reviewers: nicolasvasilache, ftynse, reidtatge, bkramer, andydavis1

Reviewed By: nicolasvasilache, ftynse, andydavis1

Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D79815

Added: 
    mlir/test/Target/vector-to-llvm-ir.mlir

Modified: 
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 6e3e681ad815..ef536e2194c1 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1187,6 +1187,55 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
   }
 };
 
+/// Progressive lowering of ConstantMaskOp.
+/// One:
+///   %x = vector.constant_mask [a,b]
+/// is replaced by:
+///   %l = vector.constant_mask [b]
+///   %z = zero-result
+///   %0 = vector.insert %l, %z[0]
+///   ..
+///   %x = vector.insert %l, %..[a-1]
+/// which will be folded at LLVM IR level; if a == 0, %x is just %z.
+class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
+public:
+  using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto dstType = op.getResult().getType().cast<VectorType>();
+    auto eltType = dstType.getElementType();
+    auto dimSizes = op.mask_dim_sizes();
+    int64_t rank = dimSizes.size();
+    int64_t trueDim = dimSizes[0].cast<IntegerAttr>().getInt();
+    Attribute zero = rewriter.getZeroAttr(dstType);
+    if (trueDim == 0) { // all-false: never materialize the (dead) true value
+      rewriter.replaceOpWithNewOp<ConstantOp>(op, dstType, zero);
+      return success();
+    }
+    Value trueVal;
+    if (rank == 1) {
+      trueVal = rewriter.create<ConstantOp>(
+          loc, eltType, rewriter.getIntegerAttr(eltType, 1));
+    } else {
+      VectorType lowType =
+          VectorType::get(dstType.getShape().drop_front(), eltType);
+      SmallVector<int64_t, 4> newDimSizes;
+      for (int64_t r = 1; r < rank; r++)
+        newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
+      trueVal = rewriter.create<vector::ConstantMaskOp>(
+          loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
+    }
+    Value result = rewriter.create<ConstantOp>(loc, dstType, zero);
+    for (int64_t d = 0; d < trueDim; d++)
+      result = rewriter.create<vector::InsertOp>(
+          loc, dstType, trueVal, result, rewriter.getI64ArrayAttr(d));
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 /// Progressive lowering of ContractionOp.
 /// One:
 ///   %x = vector.contract with at least one free/batch dimension
@@ -1609,6 +1658,7 @@ void mlir::vector::populateVectorContractLoweringPatterns(
     VectorTransformsOptions parameters) {
   patterns.insert<ShapeCastOp2DDownCastRewritePattern,
                   ShapeCastOp2DUpCastRewritePattern, BroadcastOpLowering,
-                  TransposeOpLowering, OuterProductOpLowering>(context);
+                  TransposeOpLowering, OuterProductOpLowering,
+                  ConstantMaskOpLowering>(context);
   patterns.insert<ContractionOpLowering>(parameters, context);
 }

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 9e0b9464689c..1c23072b6109 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -917,3 +917,20 @@ func @transfer_read_1d_non_zero_addrspace(%A : memref<?xf32, 3>, %base: index) -
 //  CHECK-SAME: (!llvm<"float addrspace(3)*">, !llvm.i64) -> !llvm<"float addrspace(3)*">
 //       CHECK: %[[vecPtr_b:.*]] = llvm.addrspacecast %[[gep_b]] :
 //  CHECK-SAME: !llvm<"float addrspace(3)*"> to !llvm<"<17 x float>*">
+
+func @genbool_1d() -> vector<8xi1> {
+  %0 = vector.constant_mask [4] : vector<8xi1>
+  return %0 : vector<8xi1>
+}
+// CHECK-LABEL: func @genbool_1d
+// CHECK: %[[T0:.*]] = llvm.mlir.constant(1 : i1) : !llvm.i1
+// CHECK: %[[T1:.*]] = llvm.mlir.constant(dense<false> : vector<8xi1>) : !llvm<"<8 x i1>">
+// CHECK: %[[T2:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64
+// CHECK: %[[T3:.*]] = llvm.insertelement %[[T0]], %[[T1]][%[[T2]] : !llvm.i64] : !llvm<"<8 x i1>">
+// CHECK: %[[T4:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64
+// CHECK: %[[T5:.*]] = llvm.insertelement %[[T0]], %[[T3]][%[[T4]] : !llvm.i64] : !llvm<"<8 x i1>">
+// CHECK: %[[T6:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64
+// CHECK: %[[T7:.*]] = llvm.insertelement %[[T0]], %[[T5]][%[[T6]] : !llvm.i64] : !llvm<"<8 x i1>">
+// CHECK: %[[T8:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i64
+// CHECK: %[[T9:.*]] = llvm.insertelement %[[T0]], %[[T7]][%[[T8]] : !llvm.i64] : !llvm<"<8 x i1>">
+// CHECK: llvm.return %[[T9]] : !llvm<"<8 x i1>">

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 563c9f76d9b2..a4d4fe58e8d5 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -559,3 +559,49 @@ func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32>
   %0 = vector.broadcast %arg0 : vector<4x1x2xf32> to vector<4x3x2xf32>
   return %0 : vector<4x3x2xf32>
 }
+
+// CHECK-LABEL: func @genbool_1d
+// CHECK: %[[TT:.*]] = constant 1 : i1
+// CHECK: %[[C1:.*]] = constant dense<false> : vector<8xi1>
+// CHECK: %[[T0:.*]] = vector.insert %[[TT]], %[[C1]] [0] : i1 into vector<8xi1>
+// CHECK: %[[T1:.*]] = vector.insert %[[TT]], %[[T0]] [1] : i1 into vector<8xi1>
+// CHECK: %[[T2:.*]] = vector.insert %[[TT]], %[[T1]] [2] : i1 into vector<8xi1>
+// CHECK: %[[T3:.*]] = vector.insert %[[TT]], %[[T2]] [3] : i1 into vector<8xi1>
+// CHECK: return %[[T3]] : vector<8xi1>
+
+func @genbool_1d() -> vector<8xi1> {
+  %0 = vector.constant_mask [4] : vector<8xi1>
+  return %0 : vector<8xi1>
+}
+
+// CHECK-LABEL: func @genbool_2d
+// CHECK: %[[TT:.*]] = constant 1 : i1
+// CHECK: %[[ZROW:.*]] = constant dense<false> : vector<4xi1>
+// CHECK: %[[ZMAT:.*]] = constant dense<false> : vector<4x4xi1>
+// CHECK: %[[ROW0:.*]] = vector.insert %[[TT]], %[[ZROW]] [0] : i1 into vector<4xi1>
+// CHECK: %[[ROW1:.*]] = vector.insert %[[TT]], %[[ROW0]] [1] : i1 into vector<4xi1>
+// CHECK: %[[MAT0:.*]] = vector.insert %[[ROW1]], %[[ZMAT]] [0] : vector<4xi1> into vector<4x4xi1>
+// CHECK: %[[MAT1:.*]] = vector.insert %[[ROW1]], %[[MAT0]] [1] : vector<4xi1> into vector<4x4xi1>
+// CHECK: return %[[MAT1]] : vector<4x4xi1>
+
+func @genbool_2d() -> vector<4x4xi1> {
+  %mask = vector.constant_mask [2, 2] : vector<4x4xi1>
+  return %mask : vector<4x4xi1>
+}
+
+// CHECK-LABEL: func @genbool_3d
+// CHECK: %[[TT:.*]] = constant 1 : i1
+// CHECK: %[[C1:.*]] = constant dense<false> : vector<4xi1>
+// CHECK: %[[C2:.*]] = constant dense<false> : vector<3x4xi1>
+// CHECK: %[[C3:.*]] = constant dense<false> : vector<2x3x4xi1>
+// CHECK: %[[T0:.*]] = vector.insert %[[TT]], %[[C1]] [0] : i1 into vector<4xi1>
+// CHECK: %[[T1:.*]] = vector.insert %[[TT]], %[[T0]] [1] : i1 into vector<4xi1>
+// CHECK: %[[T2:.*]] = vector.insert %[[TT]], %[[T1]] [2] : i1 into vector<4xi1>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<4xi1> into vector<3x4xi1>
+// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C3]] [0] : vector<3x4xi1> into vector<2x3x4xi1>
+// CHECK: return %[[T4]] : vector<2x3x4xi1>
+
+func @genbool_3d() -> vector<2x3x4xi1> {
+  %v = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1>
+  return %v: vector<2x3x4xi1>
+}

diff  --git a/mlir/test/Target/vector-to-llvm-ir.mlir b/mlir/test/Target/vector-to-llvm-ir.mlir
new file mode 100644
index 000000000000..354f70f3f4b9
--- /dev/null
+++ b/mlir/test/Target/vector-to-llvm-ir.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm | mlir-translate -mlir-to-llvmir | FileCheck %s
+
+func @genbool_1d() -> vector<8xi1> {
+  %mask = vector.constant_mask [4] : vector<8xi1>
+  return %mask : vector<8xi1>
+}
+// CHECK-LABEL: @genbool_1d()
+// CHECK-NEXT: ret <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false>
+
+func @genbool_2d() -> vector<4x4xi1> {
+  %mask = vector.constant_mask [2, 2] : vector<4x4xi1>
+  return %mask : vector<4x4xi1>
+}
+// CHECK-LABEL: @genbool_2d()
+// CHECK-NEXT: ret [4 x <4 x i1>] [<4 x i1> <i1 true, i1 true, i1 false, i1 false>, <4 x i1> <i1 true, i1 true, i1 false, i1 false>, <4 x i1> zeroinitializer, <4 x i1> zeroinitializer]
+
+func @genbool_3d() -> vector<2x3x4xi1> {
+  %v = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1>
+  return %v: vector<2x3x4xi1>
+}
+// CHECK-LABEL: @genbool_3d()
+// CHECK-NEXT: ret [2 x [3 x <4 x i1>]] {{\[\[}}3 x <4 x i1>] [<4 x i1> <i1 true, i1 true, i1 true, i1 false>, <4 x i1> zeroinitializer, <4 x i1> zeroinitializer], [3 x <4 x i1>] zeroinitializer]
+// note: a literal "[[" would open a FileCheck variable, so it is matched via the regex {{\[\[}}


        


More information about the Mlir-commits mailing list