[Mlir-commits] [mlir] 6a3c69e - [mlir][spirv] Infer converted type of scf.for from the init value

Thomas Raoux llvmlistbot at llvm.org
Tue Aug 25 23:36:11 PDT 2020


Author: Thomas Raoux
Date: 2020-08-25T23:35:01-07:00
New Revision: 6a3c69e918b13482f2f8492ddd3a79ccdcb70f76

URL: https://github.com/llvm/llvm-project/commit/6a3c69e918b13482f2f8492ddd3a79ccdcb70f76
DIFF: https://github.com/llvm/llvm-project/commit/6a3c69e918b13482f2f8492ddd3a79ccdcb70f76.diff

LOG: [mlir][spirv] Infer converted type of scf.for from the init value

Instead of using the TypeConverter, infer the type of the alloca created for
each loop-carried value from the corresponding init value. This allows some
ambiguous types, such as multidimensional vectors, to be converted correctly.

Differential Revision: https://reviews.llvm.org/D86582

Added: 
    

Modified: 
    mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index b8eb87c80368..9c5f8393e8fd 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -91,13 +91,13 @@ template <typename ScfOp, typename OpTy>
 static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
                                   SPIRVTypeConverter &typeConverter,
                                   ConversionPatternRewriter &rewriter,
-                                  ScfToSPIRVContextImpl *scfToSPIRVContext) {
+                                  ScfToSPIRVContextImpl *scfToSPIRVContext,
+                                  ArrayRef<Type> returnTypes) {
 
   Location loc = scfOp.getLoc();
   auto &allocas = scfToSPIRVContext->outputVars[newOp];
   SmallVector<Value, 8> resultValue;
-  for (Value result : scfOp.results()) {
-    auto convertedType = typeConverter.convertType(result.getType());
+  for (Type convertedType : returnTypes) {
     auto pointerType =
         spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
     rewriter.setInsertionPoint(newOp);
@@ -185,8 +185,15 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
       loc, newIndVar.getType(), newIndVar, forOperands.step());
   rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
 
+  // Infer the return types from the init operands. Vector type may get
+  // converted to CooperativeMatrix or to Vector type, to avoid having complex
+  // extra logic to figure out the right type we just infer it from the Init
+  // operands.
+  SmallVector<Type, 8> initTypes;
+  for (auto arg : forOperands.initArgs())
+    initTypes.push_back(arg.getType());
   replaceSCFOutputValue(forOp, loopOp, typeConverter, rewriter,
-                        scfToSPIRVContext);
+                        scfToSPIRVContext, initTypes);
   return success();
 }
 
@@ -238,8 +245,13 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
                                               thenBlock, ArrayRef<Value>(),
                                               elseBlock, ArrayRef<Value>());
 
+  SmallVector<Type, 8> returnTypes;
+  for (auto result : ifOp.results()) {
+    auto convertedType = typeConverter.convertType(result.getType());
+    returnTypes.push_back(convertedType);
+  }
   replaceSCFOutputValue(ifOp, selectionOp, typeConverter, rewriter,
-                        scfToSPIRVContext);
+                        scfToSPIRVContext, returnTypes);
   return success();
 }
 


        


More information about the Mlir-commits mailing list