[Mlir-commits] [mlir] 9912bed - [mlir][linalg] Remove RangeOp and RangeType.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 14 23:19:47 PST 2021


Author: gysit
Date: 2021-12-15T07:19:10Z
New Revision: 9912bed7306f0157b6425f233eec56d04bb536cf

URL: https://github.com/llvm/llvm-project/commit/9912bed7306f0157b6425f233eec56d04bb536cf
DIFF: https://github.com/llvm/llvm-project/commit/9912bed7306f0157b6425f233eec56d04bb536cf.diff

LOG: [mlir][linalg] Remove RangeOp and RangeType.

Remove RangeOp and RangeType, which are no longer actively used. After removing RangeType, the LinalgTypes header only includes the generated dialect header.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D115727

Added: 
    

Modified: 
    mlir/docs/Dialects/Linalg/_index.md
    mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
    mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
    mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir

Removed: 
    mlir/test/Dialect/Linalg/llvm.mlir


################################################################################
diff  --git a/mlir/docs/Dialects/Linalg/_index.md b/mlir/docs/Dialects/Linalg/_index.md
index 3c2742ac51f18..790f858dad262 100644
--- a/mlir/docs/Dialects/Linalg/_index.md
+++ b/mlir/docs/Dialects/Linalg/_index.md
@@ -520,7 +520,6 @@ generally alias the operand `view`. At the moment the existing ops are:
 * `memref.view`,
 * `memref.subview`,
 * `memref.transpose`.
-* `linalg.range`,
 * `linalg.slice`,
 * `linalg.reshape`,
 ```

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index de3703b71acb0..de5bc6d33e678 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -58,8 +58,4 @@ def Linalg_Dialect : Dialect {
   }];
 }
 
-// Whether a type is a RangeType.
-def LinalgIsRangeTypePred : CPred<"$_self.isa<RangeType>()">;
-def Range : DialectType<Linalg_Dialect, LinalgIsRangeTypePred, "range">;
-
 #endif // LINALG_BASE

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index df39b311f410b..a5c7561981927 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -330,34 +330,6 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
   let hasFolder = 1;
 }
 
-def Linalg_RangeOp :
-    Linalg_Op<"range", [NoSideEffect]>,
-    Arguments<(ins Index:$min, Index:$max, Index:$step)>,
-    Results<(outs Range)> {
-  let summary = "Create a `range` type value, used to create `view`s";
-  let description = [{
-    The `linalg.range` op creates a `!linalg.range` from 3 values of type
-    `index` that represent the min, max and step values of the `range`. This
-    type does not pass function boundaries at the moment.
-
-    Example:
-
-    ```mlir
-    %3 = linalg.range %0:%1:%2 : !linalg.range
-    ````
-  }];
-  let builders = [
-    OpBuilder<(ins "Value":$min, "Value":$max, "Value":$step),
-    [{
-      auto rangeType = RangeType::get($_builder.getContext());
-      build($_builder, $_state, rangeType, min, max, step);
-    }]>];
-
-  // Fully specified by traits.
-  let verifier = ?;
-  let assemblyFormat = "$min `:` $max `:` $step attr-dict `:` type(results)";
-}
-
 def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>,
     Arguments<(ins Variadic<AnyType>:$values)> {
   let summary = "Linalg yield operation";

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
index 396cc3b591201..3c99ecb4dda1d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
@@ -22,27 +22,4 @@
 
 #include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"
 
-namespace mlir {
-class MLIRContext;
-
-namespace linalg {
-
-/// A RangeType represents a minimal range abstraction (min, max, step).
-/// It is constructed by calling the linalg.range op with three values index of
-/// index type:
-///
-/// ```mlir
-///    func @foo(%arg0 : index, %arg1 : index, %arg2 : index) {
-///      %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
-///    }
-/// ```
-class RangeType : public Type::TypeBase<RangeType, Type, TypeStorage> {
-public:
-  // Used for generic hooks in TypeBase.
-  using Base::Base;
-};
-
-} // namespace linalg
-} // namespace mlir
-
 #endif // MLIR_DIALECT_LINALG_LINALGTYPES_H_

diff  --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 713890425acc6..478d73c07ac71 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -52,48 +52,7 @@ static Type getPtrToElementType(T containerType, LLVMTypeConverter &lowering) {
       lowering.convertType(containerType.getElementType()));
 }
 
-/// Convert the given range descriptor type to the LLVMIR dialect.
-/// Range descriptor contains the range bounds and the step as 64-bit integers.
-///
-/// struct {
-///   int64_t min;
-///   int64_t max;
-///   int64_t step;
-/// };
-static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
-  auto *context = t.getContext();
-  auto int64Ty = converter.convertType(IntegerType::get(context, 64));
-  return LLVMStructType::getLiteral(context, {int64Ty, int64Ty, int64Ty});
-}
-
 namespace {
-// RangeOp creates a new range descriptor.
-class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> {
-public:
-  using ConvertOpToLLVMPattern<RangeOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(RangeOp rangeOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto rangeDescriptorTy = convertRangeType(
-        rangeOp.getType().cast<RangeType>(), *getTypeConverter());
-
-    ImplicitLocOpBuilder b(rangeOp->getLoc(), rewriter);
-
-    // Fill in an aggregate value of the descriptor.
-    Value desc = b.create<LLVM::UndefOp>(rangeDescriptorTy);
-    desc = b.create<LLVM::InsertValueOp>(desc, adaptor.min(),
-                                         rewriter.getI64ArrayAttr(0));
-    desc = b.create<LLVM::InsertValueOp>(desc, adaptor.max(),
-                                         rewriter.getI64ArrayAttr(1));
-    desc = b.create<LLVM::InsertValueOp>(desc, adaptor.step(),
-                                         rewriter.getI64ArrayAttr(2));
-    rewriter.replaceOp(rangeOp, desc);
-    return success();
-  }
-};
-
-
 // YieldOp produces and LLVM::ReturnOp.
 class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
 public:
@@ -111,11 +70,7 @@ class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
 /// Populate the given list with patterns that convert from Linalg to LLVM.
 void mlir::populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                   RewritePatternSet &patterns) {
-  patterns.add<RangeOpConversion, YieldOpConversion>(converter);
-
-  // Populate the type conversions for the linalg types.
-  converter.addConversion(
-      [&](RangeType type) { return convertRangeType(type, converter); });
+  patterns.add<YieldOpConversion>(converter);
 }
 
 namespace {
@@ -135,7 +90,6 @@ void ConvertLinalgToLLVMPass::runOnOperation() {
   populateMemRefToLLVMConversionPatterns(converter, patterns);
 
   LLVMConversionTarget target(getContext());
-  target.addIllegalOp<RangeOp>();
   target.addLegalOp<ModuleOp>();
   if (failed(applyPartialConversion(module, target, std::move(patterns))))
     signalPassFailure();

diff  --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index 6227d21521e41..e0f909d482973 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -187,7 +187,7 @@ void ConvertLinalgToStandardPass::runOnOperation() {
   target.addLegalDialect<AffineDialect, arith::ArithmeticDialect,
                          memref::MemRefDialect, scf::SCFDialect,
                          StandardOpsDialect>();
-  target.addLegalOp<ModuleOp, FuncOp, ReturnOp, linalg::RangeOp>();
+  target.addLegalOp<ModuleOp, FuncOp, ReturnOp>();
   RewritePatternSet patterns(&getContext());
   populateLinalgToStandardConversionPatterns(patterns);
   if (failed(applyFullConversion(module, target, std::move(patterns))))

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
index 4d004352c6839..c06c0f4a76c21 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -106,7 +106,6 @@ void addNamedOpBuilders(
 }
 
 void mlir::linalg::LinalgDialect::initialize() {
-  addTypes<RangeType>();
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
@@ -125,29 +124,6 @@ void mlir::linalg::LinalgDialect::initialize() {
   addInterfaces<LinalgInlinerInterface>();
 }
 
-Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser) const {
-  // Parse the main keyword for the type.
-  StringRef keyword;
-  if (parser.parseKeyword(&keyword))
-    return Type();
-  MLIRContext *context = getContext();
-
-  // Handle 'range' types.
-  if (keyword == "range")
-    return RangeType::get(context);
-
-  parser.emitError(parser.getNameLoc(), "unknown Linalg type: " + keyword);
-  return Type();
-}
-
-/// RangeType prints as just "range".
-static void print(RangeType rt, DialectAsmPrinter &os) { os << "range"; }
-
-void mlir::linalg::LinalgDialect::printType(Type type,
-                                            DialectAsmPrinter &os) const {
-  print(type.cast<RangeType>(), os);
-}
-
 LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op,
                                                       NamedAttribute attr) {
   using comprehensive_bufferize::BufferizableOpInterface;

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 5c7549cd45ad4..ab51fe5445b51 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -298,16 +298,6 @@ func @generic(%arg0: memref<?x?xi4>) {
 //
 // // -----
 
-// expected-error @+1 {{unknown Linalg type}}
-!invalid_type = type !linalg.unknown
-
-// -----
-
-// expected-error @+1 {{expected valid keyword}}
-!invalid_type = type !linalg<"?">
-
-// -----
-
 func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?xf32>, %c3: memref<?x?x?xf32>) {
   // expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #1 (3)}}
   linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?xf32>)

diff  --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
deleted file mode 100644
index f6ab826ae151e..0000000000000
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ /dev/null
@@ -1,15 +0,0 @@
-// RUN: mlir-opt %s -convert-linalg-to-llvm | FileCheck %s
-
-func @range(%arg0: index) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %R = linalg.range %c0:%arg0:%c1 : !linalg.range
-  return
-}
-// CHECK-LABEL: func @range
-//       CHECK:   arith.constant 0 : index
-//       CHECK:   arith.constant 1 : index
-//       CHECK:   llvm.mlir.undef : !llvm.struct<(i64, i64, i64)>
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(i64, i64, i64)>
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(i64, i64, i64)>
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(i64, i64, i64)>

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index cac1d4448e530..f9559cbd5b96b 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -86,20 +86,10 @@ func @pad_to_static_size(%arg0: tensor<?x?xf32>, %ub0: index, %ub1: index,
 
 // -----
 
-func @range(%arg0: index, %arg1: index, %arg2: index) {
-  %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
-  return
-}
-// CHECK-LABEL: func @range(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
-//  CHECK-NEXT:  linalg.range %{{.*}} : %{{.*}} : %{{.*}} : !linalg.range
-
-// -----
-
-func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
+func @views(%arg0: index) {
   %c0 = arith.constant 0 : index
   %0 = arith.muli %arg0, %arg0 : index
   %1 = memref.alloc (%0) : memref<?xi8>
-  %2 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
   %3 = memref.view %1[%c0][%arg0, %arg0] : memref<?xi8> to memref<?x?xf32>
   %4 = memref.view %1[%c0][%arg0, %arg0] : memref<?xi8> to memref<?x?xvector<4x4xf32>>
   memref.dealloc %1 : memref<?xi8>
@@ -108,7 +98,6 @@ func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index
 // CHECK-LABEL: func @views
 //  CHECK:  arith.muli %{{.*}}, %{{.*}} : index
 //  CHECK-NEXT:  memref.alloc(%{{.*}}) : memref<?xi8>
-//  CHECK-NEXT:  range
 //  CHECK-NEXT:  memref.view %{{.*}}[%{{.*}}][%{{.*}}] :
 //  CHECK-SAME:     memref<?xi8> to memref<?x?xf32>
 //  CHECK-NEXT:  memref.view %{{.*}}[%{{.*}}][%{{.*}}] :


        


More information about the Mlir-commits mailing list