CoroutinesCodeGenerator.cpp
Go to the documentation of this file.
1/******************************************************************************
2 *
3 * C++ Insights, copyright (C) by Andreas Fertig
4 * Distributed under an MIT license. See LICENSE for details
5 *
6 ****************************************************************************/
7
8#include <iterator>
9#include <optional>
10#include <vector>
11#include "ASTHelpers.h"
12#include "CodeGenerator.h"
13#include "DPrint.h"
14#include "Insights.h"
15#include "InsightsHelpers.h"
16#include "NumberIterator.h"
17
18#include <algorithm>
19//-----------------------------------------------------------------------------
20
21namespace ranges = std::ranges;
22//-----------------------------------------------------------------------------
23
24namespace clang::insights {
25
// Name of the coroutine-frame pointer. Used as the parameter name of the made-up resume/destroy functions and
// as the name of the frame variable InsertCoroutine creates.
26constexpr std::string_view CORO_FRAME_NAME{"__f"sv};
// Textual prefix ("__f->") used when emitting accesses to members of the coroutine frame.
27const std::string CORO_FRAME_ACCESS{StrCat(CORO_FRAME_NAME, "->"sv)};
// Name of the frame field that records at which suspension point the coroutine has to resume.
29const std::string SUSPEND_INDEX_NAME{BuildInternalVarName("suspend_index"sv)};
// NOTE(review): presumably the name of the frame flag behind mInitialAwaitResumeCalledField, which tracks whether
// initial_suspend's await_suspend was already called (dcl.fct.def.coroutine#5.3) - confirm.
30const std::string INITIAL_AWAIT_SUSPEND_CALLED_NAME{BuildInternalVarName("initial_await_suspend_called"sv)};
// Prefix of the labels the made-up resume switch jumps to (see BuildResumeLabelName).
31const std::string RESUME_LABEL_PREFIX{BuildInternalVarName("resume"sv)};
// Name of the label in front of the final suspend point. An operand-less co_return and the end of the
// coroutine body jump there via goto.
32const std::string FINAL_SUSPEND_NAME{BuildInternalVarName("final_suspend"sv)};
33//-----------------------------------------------------------------------------
34
35using namespace asthelpers;
36//-----------------------------------------------------------------------------
37
38QualType CoroutinesCodeGenerator::GetFramePointerType() const
39{
40 return Ptr(GetFrameType());
41}
42//-----------------------------------------------------------------------------
43
45{
 // Emits the made-up coroutine frame struct. This only happens when a frame type exists and InsertCoroutine has
 // requested the insertion by setting mDoInsertInDtor.
46 RETURN_IF(not(mASTData.mFrameType and mASTData.mDoInsertInDtor));
47
 // Finalize the record before it gets printed.
48 mASTData.mFrameType->completeDefinition();
49
51
52 // Using the "normal" CodeGenerator here as this is only about inserting the made up coroutine-frame.
53 CodeGeneratorVariant codeGenerator{ofm};
54 codeGenerator->InsertArg(mASTData.mFrameType);
55
56 // Insert the made-up struct before the function declaration
57 mOutputFormatHelper.InsertAt(mPosBeforeFunc, ofm);
58}
59//-----------------------------------------------------------------------------
60
61static FieldDecl* AddField(CoroutineASTData& astData, std::string_view name, QualType type)
62{
63 if(nullptr == astData.mFrameType) {
64 return nullptr;
65 }
66
67 auto* fieldDecl = mkFieldDecl(astData.mFrameType, name, type);
68
69 astData.mFrameType->addDecl(fieldDecl);
70
71 return fieldDecl;
72}
73//-----------------------------------------------------------------------------
74
75FieldDecl* CoroutinesCodeGenerator::AddField(std::string_view name, QualType type)
76{
77 return ::clang::insights::AddField(mASTData, name, type);
78}
79//-----------------------------------------------------------------------------
80
81static auto* CreateCoroFunctionDecl(std::string funcName, QualType type)
82{
83 params_vector params{{CORO_FRAME_NAME, type}};
84 const std::string coroFsmName{BuildInternalVarName(funcName)};
85
86 return Function(coroFsmName, VoidTy(), params);
87}
88//-----------------------------------------------------------------------------
89
90static void SetFunctionBody(FunctionDecl* fd, StmtsContainer& bodyStmts)
91{
92 fd->setBody(mkCompoundStmt(bodyStmts));
93}
94//-----------------------------------------------------------------------------
95
/// Builds the base name for the frame members created for a suspend expression's opaque value. The same base
/// name, with a "_res" suffix, is used for the member holding a non-void resume result.
/// The name is derived from the line and column of the opaque value's source expression; "suspend_"sv is
/// passed to MakeLineColumnName as the name prefix.
96static std::string BuildSuspendVarName(const OpaqueValueExpr* stmt)
97{
99 MakeLineColumnName(GetGlobalAST().getSourceManager(), stmt->getSourceExpr()->getBeginLoc(), "suspend_"sv));
100}
101//-----------------------------------------------------------------------------
102
103/// \brief Find the suspend expressions (\c co_await / \c co_yield) in a coroutine body statement for early
/// transformation.
104///
105/// Traverse the whole CoroutineBodyStmt to find all appearing \c VarDecl. These need to be rerouted to the
106/// coroutine frame and hence prefixed by something like __f->. For that reason we only look for \c VarDecls
107/// directly appearing in the body. For a \c CallExpr only the arguments are visited, not the callee.
///
/// Additionally, every \c co_await / \c co_yield increments the shared suspends counter, \c this is replaced by
/// an access to a frame member, and the result of a non-void \c co_await is replaced by a reference to a made-up
/// frame variable. The transformation runs eagerly in the constructor.
108class CoroutineASTTransformer : public StmtVisitor<CoroutineASTTransformer>
109{
 // Statements collected while rebuilding the current CompoundStmt.
110 StmtsContainer mBodyStmts{};
 // Parent statement in which ReplaceNode substitutes rewritten children.
111 Stmt* mPrevStmt{}; // used to insert the suspendexpr
112 CoroutineASTData& mASTData;
 // The CompoundStmt child currently being processed (only set by VisitCompoundStmt; null otherwise).
113 Stmt* mStaged{};
 // Set when the current child was already replaced by statements pushed to mBodyStmts and must not be added.
114 bool mSkip{};
 // Refers to the caller's counter of co_await / co_yield expressions.
115 size_t& mSuspendsCount;
 // Maps local variables / parameters to the frame member expression that replaces their DeclRefExprs.
116 llvm::DenseMap<VarDecl*, MemberExpr*> mVarNamePrefix{};
117
118public:
120 size_t& suspendsCounter,
121 Stmt* stmt,
122 llvm::DenseMap<VarDecl*, MemberExpr*> varNamePrefix,
123 Stmt* prev = nullptr)
124 : mPrevStmt{prev}
125 , mASTData{coroutineASTData}
126 , mSuspendsCount{suspendsCounter}
127 , mVarNamePrefix{varNamePrefix}
128 {
 // varNamePrefix is taken by value, so variables registered inside a nested body do not leak into the
 // enclosing transformer's map.
 // Without an explicit parent, stmt itself serves as the node in which children are replaced.
129 if(nullptr == mPrevStmt) {
130 mPrevStmt = stmt;
131 }
132
133 Visit(stmt);
134 }
135
 // Null-safe entry point: several visitors pass optional children (e.g. an absent else branch or initializer).
136 void Visit(Stmt* stmt)
137 {
138 if(stmt) {
140 }
141 }
142
 // Rebuilds the compound statement from the (possibly rewritten) children and swaps it into the parent.
 // An operand-less co_return is additionally followed by a goto to the final suspend label.
143 void VisitCompoundStmt(CompoundStmt* stmt)
144 {
145 for(auto* child : stmt->body()) {
146 mStaged = child;
147 Visit(child);
148
149 if(not mSkip) {
150 mBodyStmts.Add(child);
151
152 if(const auto* coret = dyn_cast_or_null<CoreturnStmt>(child);
153 coret and (coret->getOperand() == nullptr)) {
154 mBodyStmts.Add(Goto(FINAL_SUSPEND_NAME));
155 }
156 }
157
158 mSkip = false;
159 }
160
161 auto* comp = mkCompoundStmt(mBodyStmts);
162
163 ReplaceNode(mPrevStmt, stmt, comp);
164
165 mBodyStmts.clear();
166 }
167
 // switch / do / while / if: the condition is visited here; each body is transformed by a temporary nested
 // transformer that uses the statement itself as the parent for replacements.
168 void VisitSwitchStmt(SwitchStmt* stmt)
169 {
170 Visit(stmt->getCond());
171
172 CoroutineASTTransformer{mASTData, mSuspendsCount, stmt->getBody(), mVarNamePrefix, stmt};
173 }
174
175 void VisitDoStmt(DoStmt* stmt)
176 {
177 Visit(stmt->getCond());
178
179 CoroutineASTTransformer{mASTData, mSuspendsCount, stmt->getBody(), mVarNamePrefix, stmt};
180 }
181
182 void VisitWhileStmt(WhileStmt* stmt)
183 {
184 Visit(stmt->getCond());
185
186 CoroutineASTTransformer{mASTData, mSuspendsCount, stmt->getBody(), mVarNamePrefix, stmt};
187 }
188
189 void VisitIfStmt(IfStmt* stmt)
190 {
191 Visit(stmt->getCond());
192
193 CoroutineASTTransformer{mASTData, mSuspendsCount, stmt->getThen(), mVarNamePrefix, stmt};
194
195 CoroutineASTTransformer{mASTData, mSuspendsCount, stmt->getElse(), mVarNamePrefix, stmt};
196 }
197
198 void VisitForStmt(ForStmt* stmt)
199 {
200 // technically because of the init the entire for loop should be put into a dedicated scope
201 Visit(stmt->getInit());
202
203 // Special case. A VarDecl in init will be added to the body of the function and the actual init is left
204 // untouched. Work some magic to put it in the right place.
 // NOTE(review): only the last statement pushed by VisitDeclStmt is moved back into the init. An init that
 // declares several variables leaves the other assignments in mBodyStmts - confirm this is intended.
205 if(mSkip) {
206 auto* oldInit = stmt->getInit();
207 auto* newInit = mBodyStmts.mStmts.back();
208 mBodyStmts.mStmts.pop_back();
209
210 ReplaceNode(stmt, oldInit, newInit);
211
212 mSkip = false;
213 }
214
215 Visit(stmt->getCond());
216
217 Visit(stmt->getInc());
218
219 CoroutineASTTransformer{mASTData, mSuspendsCount, stmt->getBody(), mVarNamePrefix, stmt};
220
221 mSkip = false;
222 }
223
 // Only the range statement is visited here; the body is transformed by a nested transformer.
224 void VisitCXXForRangeStmt(CXXForRangeStmt* stmt)
225 {
226 Visit(stmt->getRangeStmt());
227
228 // ignoring the loop variable should be fine.
229
230 CoroutineASTTransformer{mASTData, mSuspendsCount, stmt->getBody(), mVarNamePrefix, stmt};
231 }
232
 // Replaces a reference to a non-static local variable or parameter that was moved into the frame with the
 // corresponding frame member expression.
233 void VisitDeclRefExpr(DeclRefExpr* stmt)
234 {
235 if(auto* vd = dyn_cast_or_null<VarDecl>(stmt->getDecl())) {
236 RETURN_IF(not vd->isLocalVarDeclOrParm() or vd->isStaticLocal() or not Contains(mVarNamePrefix, vd));
237
238 auto* memberExpr = mVarNamePrefix[vd];
239
240 ReplaceNode(mPrevStmt, stmt, memberExpr);
241 }
242 }
243
 // Hoists every non-static local variable into a frame field. The declaration statement is dropped (mSkip),
 // and for each variable an assignment of its initializer to the new field is collected instead.
 // Local class declarations are added to the frame struct.
244 void VisitDeclStmt(DeclStmt* stmt)
245 {
246 for(auto* decl : stmt->decls()) {
247 if(auto* varDecl = dyn_cast_or_null<VarDecl>(decl)) {
248 if(varDecl->isStaticLocal()) {
249 continue;
250 }
251
252 // at this point a placement-new would be appropriate for at least some cases.
253
254 auto* field = AddField(mASTData, GetName(*varDecl), varDecl->getType());
255 auto* me = AccessMember(mASTData.mFrameAccessDeclRef, field);
 // NOTE(review): getInit() is null for a declaration without initializer; confirm Assign handles that.
256 auto* assign = Assign(me, field, varDecl->getInit());
257
258 mVarNamePrefix.insert(std::make_pair(varDecl, me));
259
260 Visit(varDecl->getInit());
261
262 mSkip = true;
263 mBodyStmts.Add(assign);
264
265 } else if(const auto* recordDecl = dyn_cast_or_null<CXXRecordDecl>(decl)) {
266 mASTData.mFrameType->addDecl(const_cast<CXXRecordDecl*>(recordDecl));
267 }
268 }
269 }
270
 // Replaces `this` with an access to the frame's this-member. The field is only created here, not added to the
 // frame. The first CXXThisExpr is recorded in mThisExprs, whose type is later used to add the field.
271 void VisitCXXThisExpr(CXXThisExpr* stmt)
272 {
273 auto* fieldDecl = mkFieldDecl(mASTData.mFrameType, kwInternalThis, stmt->getType());
274 auto* indirectThisMemberExpr = AccessMember(mASTData.mFrameAccessDeclRef, fieldDecl);
275
276 ReplaceNode(mPrevStmt, stmt, indirectThisMemberExpr);
277
278 if(0 == mASTData.mThisExprs.size()) {
279 mASTData.mThisExprs.push_back(stmt);
280 }
281 }
282
 // Visits only the arguments, with the call as parent for replacements. The callee is not visited.
283 void VisitCallExpr(CallExpr* stmt)
284 {
285 auto* tmp = mPrevStmt;
286 mPrevStmt = stmt;
287
288 for(auto* arg : stmt->arguments()) {
289 Visit(arg);
290 }
291
292 mPrevStmt = tmp;
293 }
294
 // Visits the callee, using the callee itself as parent for replacements (presumably to rewrite the object
 // expression of the member call).
295 void VisitCXXMemberCallExpr(CXXMemberCallExpr* stmt)
296 {
297 auto* tmp = mPrevStmt;
298 mPrevStmt = stmt->getCallee();
299
300 Visit(stmt->getCallee());
301
302 mPrevStmt = tmp;
303
305 }
306
307 void VisitCoreturnStmt(CoreturnStmt* stmt)
308 {
309 Visit(stmt->getOperand());
310 Visit(stmt->getPromiseCall());
311 }
312
 // Counts the suspension point. If the CompoundStmt child currently being processed is an ExprWithCleanups,
 // the co_yield expression itself is added to the body in place of that child.
 // NOTE(review): mStaged is null when this transformer was created for a non-compound body (e.g. an if-branch
 // without braces); isa<> must not be called with null - consider isa_and_nonnull<>.
313 void VisitCoyieldExpr(CoyieldExpr* stmt)
314 {
315 ++mSuspendsCount;
316
317 if(isa<ExprWithCleanups>(mStaged)) {
318 mBodyStmts.Add(stmt);
319 mSkip = true;
320 }
321
322 Visit(stmt->getOperand());
323 }
324
 // Counts the suspension point. For a non-void result the co_await is added to the body statements ahead of the
 // enclosing statement. In its parent it is replaced by a reference to the frame variable <suspend name>_res.
325 void VisitCoawaitExpr(CoawaitExpr* stmt)
326 {
327 ++mSuspendsCount;
328
329 if(const bool returnsVoid{stmt->getResumeExpr()->getType()->isVoidType()}; returnsVoid) {
330 Visit(stmt->getOperand());
331
332 // in the void return case there is nothing to do, because this expression (potentially) is not nested.
333 return;
334 }
335
336 mBodyStmts.Add(stmt);
337
338 // Note: Adding the frame access (__f->) to the name isn't the best, but a quick approach
339 const std::string name{StrCat(CORO_FRAME_ACCESS, BuildSuspendVarName(stmt->getOpaqueValue()), "_res"sv)};
340
341 auto* resultVar = Variable(name, stmt->getType());
342 auto* resultVarDeclRef = mkDeclRefExpr(resultVar);
343
344 ReplaceNode(mPrevStmt, stmt, resultVarDeclRef);
345
346 Visit(stmt->getCommonExpr());
347 Visit(stmt->getOperand());
348 Visit(stmt->getSuspendExpr());
349 Visit(stmt->getReadyExpr());
350 Visit(stmt->getResumeExpr());
351 }
352
 // Entry point for the whole coroutine. Creates frame fields for the promise, the suspend index and the moved
 // parameters, then transforms the body and the remaining sub-statements.
353 void VisitCoroutineBodyStmt(CoroutineBodyStmt* stmt)
354 {
355 auto* varDecl = stmt->getPromiseDecl();
356
357 mASTData.mPromiseField = AddField(mASTData, GetName(*varDecl), varDecl->getType());
358 auto* me = AccessMember(mASTData.mFrameAccessDeclRef, mASTData.mPromiseField);
359
360 mVarNamePrefix.insert(std::make_pair(varDecl, me));
361
362 auto& ctx = GetGlobalAST();
363
364 // add the suspend index variable
365 mASTData.mSuspendIndexField = AddField(mASTData, SUSPEND_INDEX_NAME, ctx.IntTy);
367
368 // https://timsong-cpp.github.io/cppwp/n4861/dcl.fct.def.coroutine#5.3
372
373 for(auto* param : stmt->getParamMoves()) {
374 if(auto* declStmt = dyn_cast_or_null<DeclStmt>(param)) {
375 if(auto* varDecl2 = dyn_cast_or_null<VarDecl>(declStmt->getSingleDecl())) {
376 // For the captured parameters we need to find the ParmVarDecl instead of the newly created VarDecl
377 if(auto* declRef = FindDeclRef(varDecl2->getAnyInitializer())) {
 // NOTE(review): the dyn_cast result is dereferenced below without a null check; cast<> would state
 // the assumption explicitly (or add a check).
378 auto* varDecl = dyn_cast<ParmVarDecl>(declRef->getDecl());
379
380 auto* field = AddField(mASTData, GetName(*varDecl), varDecl->getType());
381 auto* me = AccessMember(mASTData.mFrameAccessDeclRef, field);
382
383 mVarNamePrefix.insert(std::make_pair(const_cast<ParmVarDecl*>(varDecl), me));
384 }
385 }
386 }
387 }
388
389 Visit(stmt->getBody());
390
391 Visit(stmt->getReturnStmt());
392 Visit(stmt->getReturnValue());
393 Visit(stmt->getReturnValueInit());
394 Visit(stmt->getExceptionHandler());
395 Visit(stmt->getReturnStmtOnAllocFailure());
396 Visit(stmt->getFallthroughHandler());
397 Visit(stmt->getInitSuspendStmt());
398 Visit(stmt->getFinalSuspendStmt());
399 }
400
 // Fallback for all other statements: visit every child with this statement as parent for replacements.
401 void VisitStmt(Stmt* stmt)
402 {
403 auto* tmp = mPrevStmt;
404 mPrevStmt = stmt;
405
406 for(auto* child : stmt->children()) {
407 Visit(child);
408 }
409
410 mPrevStmt = tmp;
411 }
412};
413//-----------------------------------------------------------------------------
414
/// Emits the transformed coroutine \p fd whose body is \p stmt.
///
/// The emitted function body:
/// - allocates the frame and handles allocation failure;
/// - runs the CoroutineASTTransformer;
/// - initializes the suspend index and the initial-await flag;
/// - moves the parameters into the frame and constructs the promise;
/// - forward-declares the resume and destroy functions and stores pointers to them in the frame;
/// - calls the resume function once and returns the result.
///
/// Afterwards the resume function (containing the transformed body) and the destroy function are defined.
415void CoroutinesCodeGenerator::InsertCoroutine(const FunctionDecl& fd, const CoroutineBodyStmt* stmt)
416{
418
419 auto& ctx = GetGlobalAST();
420
 // Name of the state machine: the function name (or a line/column based name for overloaded operators),
 // followed by the template arguments with '<', ':' and '>' removed.
421 mFSMName = [&] {
422 OutputFormatHelper ofm{};
423 CodeGeneratorVariant codeGenerator{ofm};
424
425 // Coroutines can be templates and then we end up with the same FSM name but different template parameters.
426 // XXX: This will fail with NTTP's like 3.14
427 if(const auto* args = fd.getTemplateSpecializationArgs()) {
428 ofm.Append('_');
429
430 for(OnceFalse needsUnderscore{}; const auto& arg : args->asArray()) {
431 if(needsUnderscore) {
432 ofm.Append('_');
433 }
434
435 codeGenerator->InsertTemplateArg(arg);
436 }
437 }
438
439 auto str = std::move(ofm.GetString());
440 ReplaceAll(str, "<"sv, ""sv);
441 ReplaceAll(str, ":"sv, ""sv);
442 ReplaceAll(str, ">"sv, ""sv);
443
445
446 if(fd.isOverloadedOperator()) {
447 return StrCat(MakeLineColumnName(ctx.getSourceManager(), stmt->getBeginLoc(), "operator_"sv), str);
448 } else {
449 return StrCat(GetName(fd), str);
450 }
451 }();
452
453 mFrameName = BuildInternalVarName(StrCat(mFSMName, "Frame"sv));
454
455 // Insert a made up struct which holds the "captured" parameters stored in the coroutine frame
456 mASTData.mFrameType = Struct(mFrameName);
457 mASTData.mFrameAccessDeclRef = mkVarDeclRefExpr(CORO_FRAME_NAME, GetFrameType());
458
459 // The coroutine frame starts with two function pointers to the resume and destroy function. See:
460 // https://gcc.gnu.org/legacy-ml/gcc-patches/2020-01/msg01096.html:
461 // "The ABI mandates that pointers into the coroutine frame point to an area
462 // begining with two function pointers (to the resume and destroy functions
463 // described below); these are immediately followed by the "promise object"
464 // described in the standard."
465 //
466 // and
467 // https://llvm.org/docs/Coroutines.html#id72 "Coroutine Representation"
468 auto* resumeFnFd = Function(hlpResumeFn, VoidTy(), {{CORO_FRAME_NAME, GetFramePointerType()}});
469 auto resumeFnType = Ptr(resumeFnFd->getType());
470 mASTData.mResumeFnField = AddField(hlpResumeFn, resumeFnType);
471
472 auto* destroyFnFd = Function(hlpDestroyFn, VoidTy(), {{CORO_FRAME_NAME, GetFramePointerType()}});
473 auto destroyFnType = Ptr(destroyFnFd->getType());
474 mASTData.mDestroyFnField = AddField(hlpDestroyFn, destroyFnType);
475
476 // Allocate the made up frame
477 mOutputFormatHelper.AppendCommentNewLine("Allocate the frame including the promise"sv);
478 mOutputFormatHelper.AppendCommentNewLine("Note: The actual parameter new is __builtin_coro_size"sv);
479
480 auto* coroFrameVar = Variable(CORO_FRAME_NAME, GetFramePointerType());
481 auto* reicast = ReinterpretCast(GetFramePointerType(), stmt->getAllocate());
482
483 coroFrameVar->setInit(reicast);
484
485 InsertArg(coroFrameVar);
486
487 // P0057R8: [dcl.fct.def.coroutine] p8: get_return_object_on_allocation_failure indicates that new may return a
488 // nullptr. In this case return get_return_object_on_allocation_failure.
489 if(stmt->getReturnStmtOnAllocFailure()) {
490 auto* nptr = new(ctx) CXXNullPtrLiteralExpr({});
491
492 // Create an IfStmt.
493 StmtsContainer bodyStmts{stmt->getReturnStmtOnAllocFailure()};
494 auto* ifStmt = If(Equal(nptr, mASTData.mFrameAccessDeclRef), bodyStmts);
495
497 InsertArg(ifStmt);
498 }
499
501 mASTData, mSuspendsCounter, const_cast<CoroutineBodyStmt*>(stmt), llvm::DenseMap<VarDecl*, MemberExpr*>{}};
502
503 // set initial suspend count to zero.
504 auto* setSuspendIndexToZero = Assign(mASTData.mFrameAccessDeclRef, mASTData.mSuspendIndexField, Int32(0));
505 InsertArgWithNull(setSuspendIndexToZero);
506
507 // https://timsong-cpp.github.io/cppwp/n4861/dcl.fct.def.coroutine#5.3
508 auto* initializeInitialAwaitResume =
510 InsertArgWithNull(initializeInitialAwaitResume);
511
512 // Move the parameters first
513 for(auto* param : stmt->getParamMoves()) {
514 if(const auto* declStmt = dyn_cast_or_null<DeclStmt>(param)) {
515 if(const auto* varDecl = dyn_cast_or_null<VarDecl>(declStmt->getSingleDecl())) {
516 const auto varName = GetName(*varDecl);
517
519 varName,
520 " = "sv,
521 "std::forward<"sv,
522 GetName(varDecl->getType()),
523 ">("sv,
524 varName,
525 ");"sv);
526 }
527 }
528 }
529
530 // According to https://eel.is/c++draft/dcl.fct.def.coroutine#5.7 the promise_type constructor can have
531 // parameters. If so, they must be equal to the coroutines function parameters.
532 // The code here performs a _simple_ lookup for a matching ctor without using Clang's overload resolution.
533 ArrayRef<ParmVarDecl*> funParams = fd.parameters();
534 SmallVector<ParmVarDecl*, 16> funParamStorage{};
535 QualType cxxMethodType{};
536
537 if(const auto* cxxMethodDecl = dyn_cast_or_null<CXXMethodDecl>(&fd)) {
538 funParamStorage.reserve(funParams.size() + 1);
539
540 cxxMethodType = cxxMethodDecl->getFunctionObjectParameterType();
541
542 // In case we have a member function the first parameter is a reference to this. The following code injects
543 // this parameter.
544 funParamStorage.push_back(Parameter(&fd, CORO_FRAME_ACCESS_THIS, cxxMethodType));
545
546 ranges::copy(funParams, std::back_inserter(funParamStorage));
547
548 funParams = funParamStorage;
549 }
550
 // Compares parameter types with references and elaborated-type sugar stripped.
551 auto getNonRefType = [&](auto* var) -> QualType {
552 if(const auto* et = var->getType().getNonReferenceType()->template getAs<ElaboratedType>()) {
553 return et->getNamedType();
554 } else {
555 return QualType(var->getType().getNonReferenceType().getTypePtrOrNull(), 0);
556 }
557 };
558
559 SmallVector<Expr*, 16> exprs{};
560
 // Use the first promise ctor whose parameter types match the (possibly this-extended) function parameters.
561 for(auto* promiseTypeRecordDecl = mASTData.mPromiseField->getType()->getAsCXXRecordDecl();
562 auto* ctor : promiseTypeRecordDecl->ctors()) {
563
564 if(not ranges::equal(
565 ctor->parameters(), funParams, [&](auto& a, auto& b) { return getNonRefType(a) == getNonRefType(b); })) {
566 continue;
567 }
568
569 // In case of a promise ctor which takes this as the first argument, that parameter needs to be dereferenced,
570 // as it can only be taken as a reference.
571 OnceTrue derefFirstParam{};
572
573 if(not ctor->param_empty() and
574 (getNonRefType(ctor->getParamDecl(0)) == QualType(cxxMethodType.getTypePtrOrNull(), 0))) {
575 if(0 == mASTData.mThisExprs.size()) {
576 mASTData.mThisExprs.push_back(CXXThisExpr::Create(ctx, {}, Ptr(cxxMethodType), false));
577 }
578 } else {
579 (void)static_cast<bool>(derefFirstParam); // set it to false
580 }
581
582 for(const auto& fparam : funParams) {
583 if(derefFirstParam) {
584 exprs.push_back(Dref(mkDeclRefExpr(fparam)));
585
586 } else {
587 exprs.push_back(AccessMember(mASTData.mFrameAccessDeclRef, fparam));
588 }
589 }
590
591 if(funParams.size()) {
592 // The <new> header needs to be included.
594 }
595
596 break; // We've found what we were looking for
597 }
598
599 if(mASTData.mThisExprs.size()) {
601 }
602
603 // Now call the promise ctor, as it may access some of the parameters it comes at this point.
605 mOutputFormatHelper.AppendCommentNewLine("Construct the promise."sv);
606 auto* me = AccessMember(mASTData.mFrameAccessDeclRef, mASTData.mPromiseField);
607
608 auto* ctorArgs = new(ctx) InitListExpr{ctx, {}, exprs, {}};
609
 // Placement-new of the promise at the address of the frame's promise member.
610 CXXNewExpr* newFrame = New({AddrOf(me)}, ctorArgs, mASTData.mPromiseField->getType());
611
612 InsertArgWithNull(newFrame);
613
614 // Add parameters from the original function to the list
615
616 // P0057R8: [dcl.fct.def.coroutine] p5: before initial_suspend and at tops 1
617
618 // Make a call to the made up state machine function for the initial suspend
620
621 // [dcl.fct.def.coroutine]
622 mOutputFormatHelper.AppendCommentNewLine("Forward declare the resume and destroy function."sv);
623
624 auto* fsmFuncDecl = CreateCoroFunctionDecl(StrCat(mFSMName, "Resume"sv), GetFramePointerType());
625 InsertArg(fsmFuncDecl);
626 auto* deallocFuncDecl = CreateCoroFunctionDecl(StrCat(mFSMName, "Destroy"sv), GetFramePointerType());
627 InsertArg(deallocFuncDecl);
628
630
631 mOutputFormatHelper.AppendCommentNewLine("Assign the resume and destroy function pointers."sv);
632
633 auto* assignResumeFn = Assign(mASTData.mFrameAccessDeclRef, mASTData.mResumeFnField, Ref(fsmFuncDecl));
634 InsertArgWithNull(assignResumeFn);
635
636 auto* assignDestroyFn = Assign(mASTData.mFrameAccessDeclRef, mASTData.mDestroyFnField, Ref(deallocFuncDecl));
637 InsertArgWithNull(assignDestroyFn);
639
641 R"A(Call the made up function with the coroutine body for initial suspend.
642 This function will be called subsequently by coroutine_handle<>::resume()
643 which calls __builtin_coro_resume(__handle_))A"sv);
644
645 auto* callCoroFSM = Call(fsmFuncDecl, {mASTData.mFrameAccessDeclRef});
646 InsertArgWithNull(callCoroFSM);
647
650
651 InsertArg(stmt->getResultDecl());
652 InsertArg(stmt->getReturnStmt());
653
655
656 mOutputFormatHelper.CloseScope(OutputFormatHelper::NoNewLineBefore::Yes);
659
660 // add contents of the original function to the body of our made up function
661 StmtsContainer fsmFuncBodyStmts{stmt};
662
663 mOutputFormatHelper.AppendCommentNewLine("This function invoked by coroutine_handle<>::resume()"sv);
664 SetFunctionBody(fsmFuncDecl, fsmFuncBodyStmts);
665 InsertArg(fsmFuncDecl);
666
667 mASTData.mDoInsertInDtor = true; // As we have a coroutine insert the frame when this object goes out of scope.
668
669#if 0 // Preserve for later. Technically the destructor for the entire frame that's made up below takes care of
670 // everything.
671
672 // A destructor is only present, if the promise_type or one of its members is non-trivially destructible.
673 if(auto* dtor = mASTData.mPromiseField->getType()->getAsCXXRecordDecl()->getDestructor()) {
674 deallocFuncBodyStmts.Add(Comment("Deallocating the coroutine promise type"sv));
675
676 auto* promiseAccess = AccessMember(mASTData.mFrameAccessDeclRef, mASTData.mPromiseField);
677 auto* deallocPromise = AccessMember(promiseAccess, dtor, false);
678 auto* dtorCall = CallMemberFun(deallocPromise, dtor->getType());
679 deallocFuncBodyStmts.Add(dtorCall);
680
681 } else {
682 deallocFuncBodyStmts.Add(
683 Comment("promise_type is trivially destructible, no dtor required."sv));
684 }
685#endif
686
687 // This code isn't really there but it is the easiest and cleanest way to visualize the destruction of all
688 // member in the frame. The deallocation function:
689 // https://devblogs.microsoft.com/oldnewthing/20210331-00/?p=105028
691 mOutputFormatHelper.AppendCommentNewLine("This function invoked by coroutine_handle<>::destroy()"sv);
692
693 StmtsContainer deallocFuncBodyStmts{Comment("destroy all variables with dtors"sv)};
694
 // Made-up call of the frame struct's destructor through __f.
695 auto* dtorFuncDecl =
696 Function(StrCat("~"sv, GetName(*mASTData.mFrameType)), VoidTy(), {{CORO_FRAME_NAME, GetFramePointerType()}});
697 auto* deallocPromise = AccessMember(mASTData.mFrameAccessDeclRef, dtorFuncDecl);
698 auto* dtorCall = CallMemberFun(deallocPromise, GetFrameType());
699 deallocFuncBodyStmts.Add(dtorCall);
700
701 deallocFuncBodyStmts.Add(Comment("Deallocating the coroutine frame"sv));
702 deallocFuncBodyStmts.Add(
703 Comment("Note: The actual argument to delete is __builtin_coro_frame with the promise as parameter"sv));
704
705 deallocFuncBodyStmts.Add(stmt->getDeallocate());
706
707 SetFunctionBody(deallocFuncDecl, deallocFuncBodyStmts);
708 InsertArg(deallocFuncDecl);
709}
710//-----------------------------------------------------------------------------
711
/// Emits the body of the made-up resume function, in this order:
/// - a switch on the suspend index that jumps to the resume label of each counted suspension point;
/// - the initial suspend and the coroutine body statements;
/// - the fallthrough co_return, if there is one;
/// - a goto to the final suspend label.
///
/// All of this is wrapped in try/catch when the coroutine has an exception handler. It is followed by the final
/// suspend label and the final suspend expression.
712void CoroutinesCodeGenerator::InsertArg(const CoroutineBodyStmt* stmt)
713{
714 // insert a made up switch for continuing a resume
715 SwitchStmt* sstmt = Switch(mASTData.mSuspendIndexAccess);
716
717 // insert 0 with break for consistency
718 auto* initialSuspendCase = Case(0, Break());
719 StmtsContainer switchBodyStmts{initialSuspendCase};
720
 // mSuspendsCounter holds the number of co_await / co_yield expressions counted by the CoroutineASTTransformer.
721 for(const auto& i : NumberIterator{mSuspendsCounter}) {
722 switchBodyStmts.Add(Case(i + 1, Goto(BuildResumeLabelName(i + 1))));
723 }
724
725 auto* switchBody = mkCompoundStmt(switchBodyStmts);
726 sstmt->setBody(switchBody);
727
728 StmtsContainer funcBodyStmts{
729 Comment("Create a switch to get to the correct resume point"sv), sstmt, stmt->getInitSuspendStmt()};
730
731 // insert the init suspend expr
732 mState = eState::InitialSuspend;
733
 // The this-member was only referenced by the transformer so far; add it to the frame now.
734 if(mASTData.mThisExprs.size()) {
735 AddField(kwInternalThis, mASTData.mThisExprs.at(0)->getType());
736 }
737
738 mInsertVarDecl = false;
 // Local class declarations were moved into the frame struct; don't emit them in the body again.
739 mSupressRecordDecls = true;
740
741 for(const auto* c : stmt->getBody()->children()) {
742 funcBodyStmts.Add(c);
743 }
744
745 if(const auto* coReturnVoid = dyn_cast_or_null<CoreturnStmt>(stmt->getFallthroughHandler())) {
746 funcBodyStmts.Add(coReturnVoid);
747 }
748
749 auto* gotoFinalSuspend = Goto(FINAL_SUSPEND_NAME);
750 funcBodyStmts.Add(gotoFinalSuspend);
751
752 auto* body = [&]() -> Stmt* {
753 auto* tryBody = mkCompoundStmt(funcBodyStmts);
754
755 // First open the try-catch block, as we get an error when jumping across such blocks with goto
756 if(const auto* exceptionHandler = stmt->getExceptionHandler()) {
757 // If we encounter an exception before initial_suspend's await_suspend was called, we re-throw the
758 // exception.
759 auto* ifStmt = If(Not(mASTData.mInitialAwaitResumeCalledAccess), Throw());
760
761 StmtsContainer catchBodyStmts{ifStmt, exceptionHandler};
762
763 return Try(tryBody, Catch(catchBodyStmts));
764 }
765
766 return tryBody;
767 }();
768
769 InsertArg(body);
770
772
773 auto* finalSuspendLabel = Label(FINAL_SUSPEND_NAME);
774 InsertArg(finalSuspendLabel);
775 mState = eState::FinalSuspend;
776 InsertArg(stmt->getFinalSuspendStmt());
777
778 // disable prefixing names and types
779 mInsertVarDecl = true;
780}
781//-----------------------------------------------------------------------------
782
/// Emits a record declaration unless mSupressRecordDecls is set. The flag is set while the coroutine body is
/// emitted, because local class declarations are added to the made-up frame struct instead.
783void CoroutinesCodeGenerator::InsertArg(const CXXRecordDecl* stmt)
784{
785 if(not mSupressRecordDecls) {
787 }
788}
789//-----------------------------------------------------------------------------
791// We seem to need this to peel off some static_casts in a CoroutineSuspendExpr.
// While mSupressCasts is set (it is set when a suspend expression is emitted), only the sub-expression is emitted.
792void CoroutinesCodeGenerator::InsertArg(const ImplicitCastExpr* stmt)
793{
794 if(mSupressCasts) {
795 InsertArg(stmt->getSubExpr());
796 } else {
798 }
799}
800//-----------------------------------------------------------------------------
801
802// A special hack to avoid having calls to __builtin_coro_xxx as some of them result in a crash
803// of the compiler and make assumptions about the call order and function location.
// Calls to __builtin_coro_free are replaced by their first argument, and calls to __builtin_coro_size by sizeof of
// the made-up frame type.
804void CoroutinesCodeGenerator::InsertArg(const CallExpr* stmt)
805{
806 if(const auto* callee = dyn_cast_or_null<DeclRefExpr>(stmt->getCallee()->IgnoreCasts())) {
807 if(GetPlainName(*callee) == "__builtin_coro_frame"sv) {
809 return;
810
811 } else if(GetPlainName(*callee) == "__builtin_coro_free"sv) {
812 CodeGenerator::InsertArg(stmt->getArg(0));
813 return;
814
815 } else if(GetPlainName(*callee) == "__builtin_coro_size"sv) {
816 CodeGenerator::InsertArg(Sizeof(GetFrameType()));
817 return;
818 }
819 }
820
822}
823//-----------------------------------------------------------------------------
825static std::optional<std::string>
826FindValue(llvm::DenseMap<const Expr*, std::pair<const DeclRefExpr*, std::string>>& map, const Expr* key)
827{
828 if(const auto& s = map.find(key); s != map.end()) {
829 return s->second.second;
830 }
831
832 return {};
833}
834//-----------------------------------------------------------------------------
835
836void CoroutinesCodeGenerator::InsertArg(const OpaqueValueExpr* stmt)
837{
838 const auto* sourceExpr = stmt->getSourceExpr();
839
840 if(const auto& s = FindValue(mOpaqueValues, sourceExpr)) {
841 mOutputFormatHelper.Append(s.value());
842
843 } else {
844 // Needs to be internal because a user can create the same type and it gets put into the stack frame
845 std::string name{BuildSuspendVarName(stmt)};
846
847 // In case of a coroutine-template the same suspension point can occur multiple times. But to know when to add
848 // the _1 we must match the one from each instantiation. The DeclRefExpr is what distinguishes the same
849 // OpaqueValueExpr between multiple instantiations.
850 const auto* dref = FindDeclRef(sourceExpr);
851
852 // The initial_suspend and final_suspend expressions carry the same location info. If we hit such a case,
853 // make up another name.
854 // Below is a std::find_if. However, the same code looks unreadable with std::find_if
855 for(const auto lookupName{StrCat(CORO_FRAME_ACCESS, name)}; const auto& [k, value] : mOpaqueValues) {
856 if(auto [thisDeref, v] = value; (thisDeref == dref) and (v == lookupName)) {
857 name += "_1"sv;
858 break;
859 }
860 }
861
862 const auto accessName{StrCat(CORO_FRAME_ACCESS, name)};
863 mOpaqueValues.insert(std::make_pair(sourceExpr, std::make_pair(dref, accessName)));
864
865 OutputFormatHelper ofm{};
866 CoroutinesCodeGenerator codeGenerator{ofm, mPosBeforeFunc, mFSMName, mSuspendsCount, mASTData};
867
868 auto* promiseField = AddField(name, stmt->getType());
869 BinaryOperator* assignPromiseSuspend =
870 Assign(mASTData.mFrameAccessDeclRef, promiseField, stmt->getSourceExpr());
871
872 codeGenerator.InsertArg(assignPromiseSuspend);
873 ofm.AppendSemiNewLine();
874
875 ofm.SetIndent(mOutputFormatHelper);
876
877 mOutputFormatHelper.InsertAt(mPosBeforeSuspendExpr, ofm);
878 mOutputFormatHelper.Append(accessName);
879 }
880}
881//-----------------------------------------------------------------------------
882
883std::string CoroutinesCodeGenerator::BuildResumeLabelName(int index) const
884{
885 return StrCat(RESUME_LABEL_PREFIX, "_"sv, mFSMName, "_"sv, index);
886}
887//-----------------------------------------------------------------------------
888
/// Emits a co_await / co_yield as explicit code:
/// - if the ready expression is false, the suspend index is set and await_suspend is called. For a non-void
///   await_suspend this is guarded by an if, and a try/catch is added when it may throw. Then the function returns.
/// - a resume label follows.
/// - at the final suspend point the frame's destroy function pointer is called.
/// - otherwise the resume expression is emitted, and a non-void result is stored in a "<suspend name>_res" frame
///   field.
889void CoroutinesCodeGenerator::InsertArg(const CoroutineSuspendExpr* stmt)
890{
892 InsertInstantiationPoint(GetGlobalAST().getSourceManager(), stmt->getKeywordLoc(), [&] {
893 if(isa<CoawaitExpr>(stmt)) {
894 return kwCoAwaitSpace;
895 } else {
896 return kwCoYieldSpace;
897 }
898 }());
899
 // Opaque values encountered below insert their frame assignment at this position.
900 mPosBeforeSuspendExpr = mOutputFormatHelper.CurrentPos();
901
902 /// Represents an expression that might suspend coroutine execution;
903 /// either a co_await or co_yield expression.
904 ///
905 /// Evaluation of this expression first evaluates its 'ready' expression. If
906 /// that returns 'false':
907 /// -- execution of the coroutine is suspended
908 /// -- the 'suspend' expression is evaluated
909 /// -- if the 'suspend' expression returns 'false', the coroutine is
910 /// resumed
911 /// -- otherwise, control passes back to the resumer.
912 /// If the coroutine is not suspended, or when it is resumed, the 'resume'
913 /// expression is evaluated, and its result is the result of the overall
914 /// expression.
915
916 // mOutputFormatHelper.AppendNewLine("// __builtin_coro_save() // frame->suspend_index = n");
917
918 // For why, see the implementation of CoroutinesCodeGenerator::InsertArg(const ImplicitCastExpr* stmt)
 // NOTE(review): mSupressCasts is not reset anywhere in this function; confirm it is meant to stay set for the
 // remaining generation.
919 mSupressCasts = true;
920
921 auto* il = Int32(++mSuspendsCount);
922 auto* bop = Assign(mASTData.mSuspendIndexAccess, mASTData.mSuspendIndexField, il);
923
924 // Find out whether the return type is void or bool. In case of bool, we need to insert an if-statement, to
925 // suspend only, if the return value was true.
926 // Technically only void, bool, or std::coroutine_handle<Z> is allowed. [expr.await] p3.7
927 const bool returnsVoid{stmt->getSuspendExpr()->getType()->isVoidType()};
928
929 // XXX: check if getResumeExpr is marked noexcept. Otherwise we need additional exception handling?
930 // CGCoroutine.cpp:229
931
932 StmtsContainer bodyStmts{};
933 Expr* initializeInitialAwaitResume = nullptr;
934
 // The suspend call is treated as throwing unless it is a direct call to a function with a nothrow prototype.
935 const bool canThrow{[&] {
936 if(const auto* e = dyn_cast_or_null<ExprWithCleanups>(stmt->getSuspendExpr())) {
937 if(const auto* ce = dyn_cast_or_null<CallExpr>(e->getSubExpr())) {
938 if(const FunctionDecl* fd = ce->getDirectCallee()) {
939 if(const FunctionProtoType* fpt = fd->getType()->getAs<FunctionProtoType>()) {
940 return not fpt->isNothrow(/*ResultIfDependent=*/false);
941 }
942 }
943 }
944 }
945
946 return true;
947 }()};
948
 // For the initial suspend, record that await_suspend was called and switch the state to the body.
949 auto addInitialAwaitSuspendCalled = [&] {
950 if(eState::InitialSuspend == mState) {
951 mState = eState::Body;
952 // https://timsong-cpp.github.io/cppwp/n4861/dcl.fct.def.coroutine#5.3
953 initializeInitialAwaitResume =
954 Assign(mASTData.mFrameAccessDeclRef, mASTData.mInitialAwaitResumeCalledField, Bool(true));
955 bodyStmts.Add(initializeInitialAwaitResume);
956 }
957 };
958
 // Wraps cont in a try block whose handler sets the suspend index to mSuspendsCount - 1 and rethrows.
959 auto insertTryCatchIfNecessary = [&](StmtsContainer& cont) {
960 if(canThrow) {
961 auto* tryBody = mkCompoundStmt(cont);
962
963 StmtsContainer catchBodyStmts{
964 Assign(mASTData.mSuspendIndexAccess, mASTData.mSuspendIndexField, Int32(mSuspendsCount - 1)), Throw()};
965
966 cont.clear();
967 cont.Add(Try(tryBody, Catch(catchBodyStmts)));
968 }
969 };
970
971 if(returnsVoid) {
972 bodyStmts.Add(bop);
973 bodyStmts.Add(stmt->getSuspendExpr());
974
975 insertTryCatchIfNecessary(bodyStmts);
976
977 addInitialAwaitSuspendCalled();
978 bodyStmts.Add(Return());
979
980 InsertArg(If(Not(stmt->getReadyExpr()), bodyStmts));
981
982 } else {
983 addInitialAwaitSuspendCalled();
984 bodyStmts.Add(Return());
985
986 auto* ifSuspend = If(stmt->getSuspendExpr(), bodyStmts);
987
988 StmtsContainer innerBodyStmts{bop, ifSuspend};
989 insertTryCatchIfNecessary(innerBodyStmts);
990
991 InsertArg(If(Not(stmt->getReadyExpr()), innerBodyStmts));
992 }
993
994 if(not returnsVoid and initializeInitialAwaitResume) {
995 // At this point we technically haven't called initial suspend
996 InsertArgWithNull(initializeInitialAwaitResume);
997 mOutputFormatHelper.AppendNewLine();
998 }
999
1000 auto* suspendLabel = Label(BuildResumeLabelName(mSuspendsCount));
1001 InsertArg(suspendLabel);
1002
 // After the final suspend point the coroutine is destroyed through the destroy function pointer in the frame.
1003 if(eState::FinalSuspend == mState) {
1004 auto* memExpr = AccessMember(mASTData.mFrameAccessDeclRef, mASTData.mDestroyFnField, true);
1005 auto* callCoroFSM = Call(memExpr, {mASTData.mFrameAccessDeclRef});
1006 InsertArg(callCoroFSM);
1007 return;
1008 }
1009
1010 const auto* resumeExpr = stmt->getResumeExpr();
1011
 // A non-void resume result is stored in the frame field "<suspend name>_res", which the
 // CoroutineASTTransformer referenced in place of the co_await.
1012 if(not resumeExpr->getType()->isVoidType()) {
1013 const auto* sourceExpr = stmt->getOpaqueValue()->getSourceExpr();
1014
1015 if(const auto& s = FindValue(mOpaqueValues, sourceExpr)) {
1016 const auto fieldName{StrCat(std::string_view{s.value()}.substr(CORO_FRAME_ACCESS.size()), "_res"sv)};
1017 mOutputFormatHelper.Append(CORO_FRAME_ACCESS, fieldName, hlpAssing);
1018
1019 AddField(fieldName, resumeExpr->getType());
1020 }
1021 }
1022
1023 InsertArg(resumeExpr);
1024}
1025//-----------------------------------------------------------------------------
1026
1027void CoroutinesCodeGenerator::InsertArg(const CoreturnStmt* stmt)
1028{
1029 InsertInstantiationPoint(GetGlobalAST().getSourceManager(), stmt->getKeywordLoc(), kwCoReturnSpace);
1030
1031 if(stmt->getPromiseCall()) {
1032 InsertArg(stmt->getPromiseCall());
1033
1034 if(stmt->isImplicit()) {
1035 mOutputFormatHelper.AppendComment("implicit"sv);
1036 }
1037 }
1038}
1039//-----------------------------------------------------------------------------
1040
1041void CoroutinesCodeGenerator::InsertArgWithNull(const Stmt* stmt)
1042{
1043 InsertArg(stmt);
1044 InsertArg(mkNullStmt());
1045}
1046//-----------------------------------------------------------------------------
1047
1048} // namespace clang::insights
const ASTContext & GetGlobalAST()
Get access to the ASTContext.
Definition Insights.cpp:71
constexpr std::string_view hlpDestroyFn
constexpr std::string_view kwCoReturnSpace
constexpr std::string_view kwInternalThis
constexpr std::string_view hlpResumeFn
constexpr std::string_view hlpAssing
#define RETURN_IF(cond)
A helper inspired by https://github.com/Microsoft/wil/wiki/Error-handling-helpers
virtual void InsertArg(const Decl *stmt)
void InsertInstantiationPoint(const SourceManager &sm, const SourceLocation &instLoc, std::string_view text={})
Inserts the instantiation point of a template.
void InsertTemplateArg(const TemplateArgument &arg)
OutputFormatHelper & mOutputFormatHelper
A special container which creates either a CodeGenerator or a CfrontCodeGenerator depending on the co...
Find the suspend expressions in a coroutine body statement for early transformation.
CoroutineASTTransformer(CoroutineASTData &coroutineASTData, size_t &suspendsCounter, Stmt *stmt, llvm::DenseMap< VarDecl *, MemberExpr * > varNamePrefix, Stmt *prev=nullptr)
A special generator for coroutines. It is only activated, if -show-coroutines-transformation is given...
void InsertArg(const ImplicitCastExpr *stmt) override
void InsertCoroutine(const FunctionDecl &fd, const CoroutineBodyStmt *body)
void OpenScope()
Open a scope by inserting a '{' followed by an indented newline.
void AppendNewLine(const char c)
Same as Append but adds a newline after the last argument.
void Append(const char c)
Append a single character.
void CloseScope(const NoNewLineBefore newLineBefore=NoNewLineBefore::No)
Close a scope by inserting a '}'.
void AppendCommentNewLine(const std::string_view &arg)
void AppendSemiNewLine()
Append a semicolon and a newline.
void InsertAt(const size_t atPos, std::string_view data)
Insert a string before the position atPos.
BinaryOperator * Equal(Expr *var, Expr *assignExpr)
UnaryOperator * AddrOf(const Expr *stmt)
CXXReinterpretCastExpr * ReinterpretCast(QualType toType, const Expr *toExpr, bool makePointer)
CXXRecordDecl * Struct(std::string_view name)
FieldDecl * mkFieldDecl(DeclContext *dc, std::string_view name, QualType type)
DeclRefExpr * mkDeclRefExpr(const ValueDecl *vd)
CXXNewExpr * New(ArrayRef< Expr * > placementArgs, const Expr *expr, QualType t)
std::vector< std::pair< std::string_view, QualType > > params_vector
Definition ASTHelpers.h:27
CallExpr * Call(const FunctionDecl *fd, ArrayRef< Expr * > params)
UnaryOperator * Ref(const Expr *e)
VarDecl * Variable(std::string_view name, QualType type, DeclContext *dc)
MemberExpr * AccessMember(const Expr *expr, const ValueDecl *vd, bool isArrow)
DeclRefExpr * mkVarDeclRefExpr(std::string_view name, QualType type)
ReturnStmt * Return(Expr *stmt)
CaseStmt * Case(int value, Stmt *stmt)
Stmt * Comment(std::string_view comment)
IfStmt * If(const Expr *condition, ArrayRef< Stmt * > bodyStmts)
GotoStmt * Goto(std::string_view labelName)
CXXCatchStmt * Catch(ArrayRef< Stmt * > body)
CXXThrowExpr * Throw(const Expr *expr)
void ReplaceNode(Stmt *parent, Stmt *oldNode, Stmt *newNode)
UnaryExprOrTypeTraitExpr * Sizeof(QualType toType)
LabelStmt * Label(std::string_view name)
UnaryOperator * Not(const Expr *stmt)
CXXBoolLiteralExpr * Bool(bool b)
FunctionDecl * Function(std::string_view name, QualType returnType, const params_vector &parameters)
CompoundStmt * mkCompoundStmt(ArrayRef< Stmt * > bodyStmts, SourceLocation beginLoc, SourceLocation endLoc)
SwitchStmt * Switch(Expr *stmt)
IntegerLiteral * Int32(uint64_t value)
ParmVarDecl * Parameter(const FunctionDecl *fd, std::string_view name, QualType type)
UnaryOperator * Dref(const Expr *stmt)
QualType Ptr(QualType srcType)
CXXStaticCastExpr * StaticCast(QualType toType, const Expr *toExpr, bool makePointer)
CXXMemberCallExpr * CallMemberFun(Expr *memExpr, QualType retType)
BinaryOperator * Assign(const VarDecl *var, Expr *assignExpr)
CXXTryStmt * Try(const Stmt *tryBody, CXXCatchStmt *catchAllBody)
const std::string SUSPEND_INDEX_NAME
const std::string RESUME_LABEL_PREFIX
bool Contains(const std::string_view source, const std::string_view search)
const DeclRefExpr * FindDeclRef(const Stmt *stmt)
Go deep into a Stmt if necessary and look through all children for a DeclRefExpr.
std::string GetPlainName(const DeclRefExpr &DRE)
static auto * CreateCoroFunctionDecl(std::string funcName, QualType type)
static void SetFunctionBody(FunctionDecl *fd, StmtsContainer &bodyStmts)
constexpr std::string_view CORO_FRAME_NAME
std::string MakeLineColumnName(const SourceManager &sm, const SourceLocation &loc, const std::string_view &prefix)
std::string GetName(const NamedDecl &nd, const QualifiedName qualifiedName)
void ReplaceAll(std::string &str, std::string_view from, std::string_view to)
std::string BuildInternalVarName(const std::string_view &varName)
const std::string CORO_FRAME_ACCESS_THIS
std::string BuildTemplateParamObjectName(std::string name)
const std::string FINAL_SUSPEND_NAME
const std::string INITIAL_AWAIT_SUSPEND_CALLED_NAME
static std::string BuildSuspendVarName(const OpaqueValueExpr *stmt)
static std::optional< std::string > FindValue(llvm::DenseMap< const Expr *, std::pair< const DeclRefExpr *, std::string > > &map, const Expr *key)
void EnableGlobalInsert(GlobalInserts idx)
Definition Insights.cpp:96
const std::string CORO_FRAME_ACCESS
static FieldDecl * AddField(CoroutineASTData &astData, std::string_view name, QualType type)
std::string StrCat(const auto &... args)
std::vector< const CXXThisExpr * > mThisExprs
A helper type to have a container for ArrayRef
Definition ASTHelpers.h:64