CoroutinesCodeGenerator.cpp
Go to the documentation of this file.
1/******************************************************************************
2 *
3 * C++ Insights, copyright (C) by Andreas Fertig
4 * Distributed under an MIT license. See LICENSE for details
5 *
6 ****************************************************************************/
7
8#include <iterator>
9#include <optional>
10#include <vector>
11#include "ASTHelpers.h"
12#include "CodeGenerator.h"
13#include "DPrint.h"
14#include "Insights.h"
15#include "InsightsHelpers.h"
16#include "NumberIterator.h"
17
18#include <algorithm>
19//-----------------------------------------------------------------------------
20
21namespace ranges = std::ranges;
22//-----------------------------------------------------------------------------
23
24namespace clang::insights {
25
// Name used in this file for the variable / parameter that refers to the made-up coroutine frame.
26constexpr std::string_view CORO_FRAME_NAME{"__f"sv};
// Textual prefix for accessing a member through the frame variable, i.e. the frame name followed by "->".
27const std::string CORO_FRAME_ACCESS{StrCat(CORO_FRAME_NAME, "->"sv)};
// Internal name of the frame field that stores the index of the current suspend point.
29const std::string SUSPEND_INDEX_NAME{BuildInternalVarName("suspend_index"sv)};
// Internal name built from "initial_await_suspend_called".
// NOTE(review): presumably names the frame field tracking whether initial_suspend's await_suspend ran -- confirm.
30const std::string INITIAL_AWAIT_SUSPEND_CALLED_NAME{BuildInternalVarName("initial_await_suspend_called"sv)};
// Prefix of the labels produced by BuildResumeLabelName.
31const std::string RESUME_LABEL_PREFIX{BuildInternalVarName("resume"sv)};
// Name of the label (and goto target) placed in front of the final suspend statement.
32const std::string FINAL_SUSPEND_NAME{BuildInternalVarName("final_suspend"sv)};
33//-----------------------------------------------------------------------------
34
35using namespace asthelpers;
36//-----------------------------------------------------------------------------
37
38QualType CoroutinesCodeGenerator::GetFramePointerType() const
39{
40 return Ptr(GetFrameType());
41}
42//-----------------------------------------------------------------------------
43
45{
46 RETURN_IF(not(mASTData.mFrameType and mASTData.mDoInsertInDtor));
47
48 mASTData.mFrameType->completeDefinition();
49
51
52 // Using the "normal" CodeGenerator here as this is only about inserting the made up coroutine-frame.
53 CodeGeneratorVariant codeGenerator{ofm};
54 codeGenerator->InsertArg(mASTData.mFrameType);
55
56 // Insert the made-up struct before the function declaration
57 mOutputFormatHelper.InsertAt(mPosBeforeFunc, ofm);
58}
59//-----------------------------------------------------------------------------
60
61static FieldDecl* AddField(CoroutineASTData& astData, std::string_view name, QualType type)
62{
63 if(nullptr == astData.mFrameType) {
64 return nullptr;
65 }
66
67 auto* fieldDecl = mkFieldDecl(astData.mFrameType, name, type);
68
69 astData.mFrameType->addDecl(fieldDecl);
70
71 return fieldDecl;
72}
73//-----------------------------------------------------------------------------
74
75FieldDecl* CoroutinesCodeGenerator::AddField(std::string_view name, QualType type)
76{
77 return ::clang::insights::AddField(mASTData, name, type);
78}
79//-----------------------------------------------------------------------------
80
81static auto* CreateCoroFunctionDecl(std::string funcName, QualType type)
82{
83 params_vector params{{CORO_FRAME_NAME, type}};
84 const std::string coroFsmName{BuildInternalVarName(funcName)};
85
86 return Function(coroFsmName, VoidTy(), params);
87}
88//-----------------------------------------------------------------------------
89
90static void SetFunctionBody(FunctionDecl* fd, StmtsContainer& bodyStmts)
91{
92 fd->setBody(mkCompoundStmt(bodyStmts));
93}
94//-----------------------------------------------------------------------------
95
96static std::string BuildSuspendVarName(const OpaqueValueExpr* stmt)
97{
99 MakeLineColumnName(GetGlobalAST().getSourceManager(), stmt->getSourceExpr()->getBeginLoc(), "suspend_"sv));
100}
101//-----------------------------------------------------------------------------
102
103/// \brief Find a \c SuspendsExpr's in a coroutine body statement for early transformation.
104///
105/// Traverse the whole CoroutineBodyStmt to find all appearing \c VarDecl. These need to be rerouted to the
106/// coroutine frame and hence prefixed by something like __f->. For that reason we only look for \c VarDecls
107/// directly appearing in the body, \c CallExpr will be skipped.
108class CoroutineASTTransformer : public StmtVisitor<CoroutineASTTransformer>
109{
110 StmtsContainer mBodyStmts{};
111 Stmt* mPrevStmt{}; // used to insert the suspendexpr
112 CoroutineASTData& mASTData;
113 Stmt* mStaged{};
114 bool mSkip{};
115 size_t& mSuspendsCount;
116 llvm::DenseMap<VarDecl*, MemberExpr*> mVarNamePrefix{};
117
118public:
120 size_t& suspendsCounter,
121 Stmt* stmt,
122 llvm::DenseMap<VarDecl*, MemberExpr*> varNamePrefix,
123 Stmt* prev = nullptr)
124 : mPrevStmt{prev}
125 , mASTData{coroutineASTData}
126 , mSuspendsCount{suspendsCounter}
127 , mVarNamePrefix{varNamePrefix}
128 {
129 if(nullptr == mPrevStmt) {
130 mPrevStmt = stmt;
131 }
132
133 Visit(stmt);
134 }
135
136 void Visit(Stmt* stmt)
137 {
138 if(stmt) {
140 }
141 }
142
143 void VisitCompoundStmt(CompoundStmt* stmt)
144 {
145 for(auto* child : stmt->body()) {
146 mStaged = child;
147 Visit(child);
148
149 if(not mSkip) {
150 mBodyStmts.Add(child);
151 }
152
153 mSkip = false;
154 }
155
156 auto* comp = mkCompoundStmt(mBodyStmts);
157
158 ReplaceNode(mPrevStmt, stmt, comp);
159
160 mBodyStmts.clear();
161 }
162
163 void VisitSwitchStmt(SwitchStmt* stmt)
164 {
165 Visit(stmt->getCond());
166
167 CoroutineASTTransformer{mASTData, mSuspendsCount, stmt->getBody(), mVarNamePrefix, stmt};
168 }
169
170 void VisitDoStmt(DoStmt* stmt)
171 {
172 Visit(stmt->getCond());
173
174 CoroutineASTTransformer{mASTData, mSuspendsCount, stmt->getBody(), mVarNamePrefix, stmt};
175 }
176
177 void VisitWhileStmt(WhileStmt* stmt)
178 {
179 Visit(stmt->getCond());
180
181 CoroutineASTTransformer{mASTData, mSuspendsCount, stmt->getBody(), mVarNamePrefix, stmt};
182 }
183
184 void VisitIfStmt(IfStmt* stmt)
185 {
186 Visit(stmt->getCond());
187
188 CoroutineASTTransformer{mASTData, mSuspendsCount, stmt->getThen(), mVarNamePrefix, stmt};
189
190 CoroutineASTTransformer{mASTData, mSuspendsCount, stmt->getElse(), mVarNamePrefix, stmt};
191 }
192
193 void VisitForStmt(ForStmt* stmt)
194 {
195 // technically because of the init the entire for loop should be put into a dedicated scope
196 Visit(stmt->getInit());
197
198 // Special case. A VarDecl in init will be added to the body of the function and the actual init is left
199 // untouched. Work some magic to put it in the right place.
200 if(mSkip) {
201 auto* oldInit = stmt->getInit();
202 auto* newInit = mBodyStmts.mStmts.back();
203 mBodyStmts.mStmts.pop_back();
204
205 ReplaceNode(stmt, oldInit, newInit);
206
207 mSkip = false;
208 }
209
210 Visit(stmt->getCond());
211
212 Visit(stmt->getInc());
213
214 CoroutineASTTransformer{mASTData, mSuspendsCount, stmt->getBody(), mVarNamePrefix, stmt};
215
216 mSkip = false;
217 }
218
219 void VisitCXXForRangeStmt(CXXForRangeStmt* stmt)
220 {
221 Visit(stmt->getRangeStmt());
222
223 // ignoring the loop variable should be fine.
224
225 CoroutineASTTransformer{mASTData, mSuspendsCount, stmt->getBody(), mVarNamePrefix, stmt};
226 }
227
228 void VisitDeclRefExpr(DeclRefExpr* stmt)
229 {
230 if(auto* vd = dyn_cast_or_null<VarDecl>(stmt->getDecl())) {
231 RETURN_IF(not vd->isLocalVarDeclOrParm() or vd->isStaticLocal() or not Contains(mVarNamePrefix, vd));
232
233 auto* memberExpr = mVarNamePrefix[vd];
234
235 ReplaceNode(mPrevStmt, stmt, memberExpr);
236 }
237 }
238
239 void VisitDeclStmt(DeclStmt* stmt)
240 {
241 for(auto* decl : stmt->decls()) {
242 if(auto* varDecl = dyn_cast_or_null<VarDecl>(decl)) {
243 if(varDecl->isStaticLocal()) {
244 continue;
245 }
246
247 // add this point a placement-new would be appropriate for at least some cases.
248
249 auto* field = AddField(mASTData, GetName(*varDecl), varDecl->getType());
250 auto* me = AccessMember(mASTData.mFrameAccessDeclRef, field);
251 auto* assign = Assign(me, field, varDecl->getInit());
252
253 mVarNamePrefix.insert(std::make_pair(varDecl, me));
254
255 Visit(varDecl->getInit());
256
257 mSkip = true;
258 mBodyStmts.Add(assign);
259
260 } else if(const auto* recordDecl = dyn_cast_or_null<CXXRecordDecl>(decl)) {
261 mASTData.mFrameType->addDecl(const_cast<CXXRecordDecl*>(recordDecl));
262 }
263 }
264 }
265
266 void VisitCXXThisExpr(CXXThisExpr* stmt)
267 {
268 auto* fieldDecl = mkFieldDecl(mASTData.mFrameType, kwInternalThis, stmt->getType());
269 auto* indirectThisMemberExpr = AccessMember(mASTData.mFrameAccessDeclRef, fieldDecl);
270
271 ReplaceNode(mPrevStmt, stmt, indirectThisMemberExpr);
272
273 if(0 == mASTData.mThisExprs.size()) {
274 mASTData.mThisExprs.push_back(stmt);
275 }
276 }
277
278 void VisitCallExpr(CallExpr* stmt)
279 {
280 auto* tmp = mPrevStmt;
281 mPrevStmt = stmt;
282
283 for(auto* arg : stmt->arguments()) {
284 Visit(arg);
285 }
286
287 mPrevStmt = tmp;
288 }
289
290 void VisitCXXMemberCallExpr(CXXMemberCallExpr* stmt)
291 {
292 auto* tmp = mPrevStmt;
293 mPrevStmt = stmt->getCallee();
294
295 Visit(stmt->getCallee());
296
297 mPrevStmt = tmp;
298
300 }
301
302 void VisitCoreturnStmt(CoreturnStmt* stmt)
303 {
304 Visit(stmt->getOperand());
305 Visit(stmt->getPromiseCall());
306 }
307
308 void VisitCoyieldExpr(CoyieldExpr* stmt)
309 {
310 ++mSuspendsCount;
311
312 if(isa<ExprWithCleanups>(mStaged)) {
313 mBodyStmts.Add(stmt);
314 mSkip = true;
315 }
316
317 Visit(stmt->getOperand());
318 }
319
320 void VisitCoawaitExpr(CoawaitExpr* stmt)
321 {
322 ++mSuspendsCount;
323
324 if(const bool returnsVoid{stmt->getResumeExpr()->getType()->isVoidType()}; returnsVoid) {
325 Visit(stmt->getOperand());
326
327 // in the void return case there is nothing to do, because this expression (potentially) is not nested.
328 return;
329 }
330
331 mBodyStmts.Add(stmt);
332
333 // Note: Adding the this pointer to the name isn't the best but a quick approach
334 const std::string name{StrCat(CORO_FRAME_ACCESS, BuildSuspendVarName(stmt->getOpaqueValue()), "_res"sv)};
335
336 auto* resultVar = Variable(name, stmt->getType());
337 auto* resultVarDeclRef = mkDeclRefExpr(resultVar);
338
339 ReplaceNode(mPrevStmt, stmt, resultVarDeclRef);
340
341 Visit(stmt->getCommonExpr());
342 Visit(stmt->getOperand());
343 Visit(stmt->getSuspendExpr());
344 Visit(stmt->getReadyExpr());
345 Visit(stmt->getResumeExpr());
346 }
347
348 void VisitCoroutineBodyStmt(CoroutineBodyStmt* stmt)
349 {
350 auto* varDecl = stmt->getPromiseDecl();
351
352 mASTData.mPromiseField = AddField(mASTData, GetName(*varDecl), varDecl->getType());
353 auto* me = AccessMember(mASTData.mFrameAccessDeclRef, mASTData.mPromiseField);
354
355 mVarNamePrefix.insert(std::make_pair(varDecl, me));
356
357 auto& ctx = GetGlobalAST();
358
359 // add the suspend index variable
360 mASTData.mSuspendIndexField = AddField(mASTData, SUSPEND_INDEX_NAME, ctx.IntTy);
362
363 // https://timsong-cpp.github.io/cppwp/n4861/dcl.fct.def.coroutine#5.3
367
368 for(auto* param : stmt->getParamMoves()) {
369 if(auto* declStmt = dyn_cast_or_null<DeclStmt>(param)) {
370 if(auto* varDecl2 = dyn_cast_or_null<VarDecl>(declStmt->getSingleDecl())) {
371 // For the captured parameters we need to find the ParmVarDecl instead of the newly created VarDecl
372 if(auto* declRef = FindDeclRef(varDecl2->getAnyInitializer())) {
373 auto* varDecl = dyn_cast<ParmVarDecl>(declRef->getDecl());
374
375 auto* field = AddField(mASTData, GetName(*varDecl), varDecl->getType());
376 auto* me = AccessMember(mASTData.mFrameAccessDeclRef, field);
377
378 mVarNamePrefix.insert(std::make_pair(const_cast<ParmVarDecl*>(varDecl), me));
379 }
380 }
381 }
382 }
383
384 Visit(stmt->getBody());
385
386 Visit(stmt->getReturnStmt());
387 Visit(stmt->getReturnValue());
388 Visit(stmt->getReturnValueInit());
389 Visit(stmt->getExceptionHandler());
390 Visit(stmt->getReturnStmtOnAllocFailure());
391 Visit(stmt->getFallthroughHandler());
392 Visit(stmt->getInitSuspendStmt());
393 Visit(stmt->getFinalSuspendStmt());
394 }
395
396 void VisitStmt(Stmt* stmt)
397 {
398 auto* tmp = mPrevStmt;
399 mPrevStmt = stmt;
400
401 for(auto* child : stmt->children()) {
402 Visit(child);
403 }
404
405 mPrevStmt = tmp;
406 }
407};
408//-----------------------------------------------------------------------------
409
410void CoroutinesCodeGenerator::InsertCoroutine(const FunctionDecl& fd, const CoroutineBodyStmt* stmt)
411{
413
414 auto& ctx = GetGlobalAST();
415
416 mFSMName = [&] {
417 OutputFormatHelper ofm{};
418 CodeGeneratorVariant codeGenerator{ofm};
419
420 // Coroutines can be templates and then we end up with the same FSM name but different template parameters.
421 // XXX: This will fail with NTTP's like 3.14
422 if(const auto* args = fd.getTemplateSpecializationArgs()) {
423 ofm.Append('_');
424
425 for(OnceFalse needsUnderscore{}; const auto& arg : args->asArray()) {
426 if(needsUnderscore) {
427 ofm.Append('_');
428 }
429
430 codeGenerator->InsertTemplateArg(arg);
431 }
432 }
433
434 auto str = std::move(ofm.GetString());
435 ReplaceAll(str, "<"sv, ""sv);
436 ReplaceAll(str, ":"sv, ""sv);
437 ReplaceAll(str, ">"sv, ""sv);
438
440
441 if(fd.isOverloadedOperator()) {
442 return StrCat(MakeLineColumnName(ctx.getSourceManager(), stmt->getBeginLoc(), "operator_"sv), str);
443 } else {
444 return StrCat(GetName(fd), str);
445 }
446 }();
447
448 mFrameName = BuildInternalVarName(StrCat(mFSMName, "Frame"sv));
449
450 // Insert a made up struct which holds the "captured" parameters stored in the coroutine frame
451 mASTData.mFrameType = Struct(mFrameName);
452 mASTData.mFrameAccessDeclRef = mkVarDeclRefExpr(CORO_FRAME_NAME, GetFrameType());
453
454 // The coroutine frame starts with two function pointers to the resume and destroy function. See:
455 // https://gcc.gnu.org/legacy-ml/gcc-patches/2020-01/msg01096.html:
456 // "The ABI mandates that pointers into the coroutine frame point to an area
457 // begining with two function pointers (to the resume and destroy functions
458 // described below); these are immediately followed by the "promise object"
459 // described in the standard."
460 //
461 // and
462 // https://llvm.org/docs/Coroutines.html#id72 "Coroutine Representation"
463 auto* resumeFnFd = Function(hlpResumeFn, VoidTy(), {{CORO_FRAME_NAME, GetFramePointerType()}});
464 auto resumeFnType = Ptr(resumeFnFd->getType());
465 mASTData.mResumeFnField = AddField(hlpResumeFn, resumeFnType);
466
467 auto* destroyFnFd = Function(hlpDestroyFn, VoidTy(), {{CORO_FRAME_NAME, GetFramePointerType()}});
468 auto destroyFnType = Ptr(destroyFnFd->getType());
469 mASTData.mDestroyFnField = AddField(hlpDestroyFn, destroyFnType);
470
471 // Allocated the made up frame
472 mOutputFormatHelper.AppendCommentNewLine("Allocate the frame including the promise"sv);
473 mOutputFormatHelper.AppendCommentNewLine("Note: The actual parameter new is __builtin_coro_size"sv);
474
475 auto* coroFrameVar = Variable(CORO_FRAME_NAME, GetFramePointerType());
476 auto* reicast = ReinterpretCast(GetFramePointerType(), stmt->getAllocate());
477
478 coroFrameVar->setInit(reicast);
479
480 InsertArg(coroFrameVar);
481
482 // P0057R8: [dcl.fct.def.coroutine] p8: get_return_object_on_allocation_failure indicates that new may return a
483 // nullptr. In this case return get_return_object_on_allocation_failure.
484 if(stmt->getReturnStmtOnAllocFailure()) {
485 auto* nptr = new(ctx) CXXNullPtrLiteralExpr({});
486
487 // Create an IfStmt.
488 StmtsContainer bodyStmts{stmt->getReturnStmtOnAllocFailure()};
489 auto* ifStmt = If(Equal(nptr, mASTData.mFrameAccessDeclRef), bodyStmts);
490
492 InsertArg(ifStmt);
493 }
494
496 mASTData, mSuspendsCounter, const_cast<CoroutineBodyStmt*>(stmt), llvm::DenseMap<VarDecl*, MemberExpr*>{}};
497
498 // set initial suspend count to zero.
499 auto* setSuspendIndexToZero = Assign(mASTData.mFrameAccessDeclRef, mASTData.mSuspendIndexField, Int32(0));
500 InsertArgWithNull(setSuspendIndexToZero);
501
502 // https://timsong-cpp.github.io/cppwp/n4861/dcl.fct.def.coroutine#5.3
503 auto* initializeInitialAwaitResume =
505 InsertArgWithNull(initializeInitialAwaitResume);
506
507 // Move the parameters first
508 for(auto* param : stmt->getParamMoves()) {
509 if(const auto* declStmt = dyn_cast_or_null<DeclStmt>(param)) {
510 if(const auto* varDecl = dyn_cast_or_null<VarDecl>(declStmt->getSingleDecl())) {
511 const auto varName = GetName(*varDecl);
512
514 varName,
515 " = "sv,
516 "std::forward<"sv,
517 GetName(varDecl->getType()),
518 ">("sv,
519 varName,
520 ");"sv);
521 }
522 }
523 }
524
525 // According to https://eel.is/c++draft/dcl.fct.def.coroutine#5.7 the promise_type constructor can have
526 // parameters. If so, they must be equal to the coroutines function parameters.
527 // The code here performs a _simple_ lookup for a matching ctor without using Clang's overload resolution.
528 ArrayRef<ParmVarDecl*> funParams = fd.parameters();
529 SmallVector<ParmVarDecl*, 16> funParamStorage{};
530 QualType cxxMethodType{};
531
532 if(const auto* cxxMethodDecl = dyn_cast_or_null<CXXMethodDecl>(&fd)) {
533 funParamStorage.reserve(funParams.size() + 1);
534
535 cxxMethodType = cxxMethodDecl->getFunctionObjectParameterType();
536
537 // In case we have a member function the first parameter is a reference to this. The following code injects
538 // this parameter.
539 funParamStorage.push_back(Parameter(&fd, CORO_FRAME_ACCESS_THIS, cxxMethodType));
540
541 ranges::copy(funParams, std::back_inserter(funParamStorage));
542
543 funParams = funParamStorage;
544 }
545
546 auto getNonRefType = [&](auto* var) -> QualType {
547 if(const auto* et = var->getType().getNonReferenceType()->template getAs<ElaboratedType>()) {
548 return et->getNamedType();
549 } else {
550 return QualType(var->getType().getNonReferenceType().getTypePtrOrNull(), 0);
551 }
552 };
553
554 SmallVector<Expr*, 16> exprs{};
555
556 for(auto* promiseTypeRecordDecl = mASTData.mPromiseField->getType()->getAsCXXRecordDecl();
557 auto* ctor : promiseTypeRecordDecl->ctors()) {
558
559 if(not ranges::equal(
560 ctor->parameters(), funParams, [&](auto& a, auto& b) { return getNonRefType(a) == getNonRefType(b); })) {
561 continue;
562 }
563
564 // In case of a promise ctor which takes this as the first argument, that parameter needs to be dereferenced,
565 // as it can only be taken as a reference.
566 OnceTrue derefFirstParam{};
567
568 if(not ctor->param_empty() and
569 (getNonRefType(ctor->getParamDecl(0)) == QualType(cxxMethodType.getTypePtrOrNull(), 0))) {
570 if(0 == mASTData.mThisExprs.size()) {
571 mASTData.mThisExprs.push_back(CXXThisExpr::Create(ctx, {}, Ptr(cxxMethodType), false));
572 }
573 } else {
574 (void)static_cast<bool>(derefFirstParam); // set it to false
575 }
576
577 for(const auto& fparam : funParams) {
578 if(derefFirstParam) {
579 exprs.push_back(Dref(mkDeclRefExpr(fparam)));
580
581 } else {
582 exprs.push_back(AccessMember(mASTData.mFrameAccessDeclRef, fparam));
583 }
584 }
585
586 if(funParams.size()) {
587 // The <new> header needs to be included.
589 }
590
591 break; // We've found what we were looking for
592 }
593
594 if(mASTData.mThisExprs.size()) {
596 }
597
598 // Now call the promise ctor, as it may access some of the parameters it comes at this point.
600 mOutputFormatHelper.AppendCommentNewLine("Construct the promise."sv);
601 auto* me = AccessMember(mASTData.mFrameAccessDeclRef, mASTData.mPromiseField);
602
603 auto* ctorArgs = new(ctx) InitListExpr{ctx, {}, exprs, {}};
604
605 CXXNewExpr* newFrame = New({AddrOf(me)}, ctorArgs, mASTData.mPromiseField->getType());
606
607 InsertArgWithNull(newFrame);
608
609 // Add parameters from the original function to the list
610
611 // P0057R8: [dcl.fct.def.coroutine] p5: before initial_suspend and at tops 1
612
613 // Make a call to the made up state machine function for the initial suspend
615
616 // [dcl.fct.def.coroutine]
617 mOutputFormatHelper.AppendCommentNewLine("Forward declare the resume and destroy function."sv);
618
619 auto* fsmFuncDecl = CreateCoroFunctionDecl(StrCat(mFSMName, "Resume"sv), GetFramePointerType());
620 InsertArg(fsmFuncDecl);
621 auto* deallocFuncDecl = CreateCoroFunctionDecl(StrCat(mFSMName, "Destroy"sv), GetFramePointerType());
622 InsertArg(deallocFuncDecl);
623
625
626 mOutputFormatHelper.AppendCommentNewLine("Assign the resume and destroy function pointers."sv);
627
628 auto* assignResumeFn = Assign(mASTData.mFrameAccessDeclRef, mASTData.mResumeFnField, Ref(fsmFuncDecl));
629 InsertArgWithNull(assignResumeFn);
630
631 auto* assignDestroyFn = Assign(mASTData.mFrameAccessDeclRef, mASTData.mDestroyFnField, Ref(deallocFuncDecl));
632 InsertArgWithNull(assignDestroyFn);
634
636 R"A(Call the made up function with the coroutine body for initial suspend.
637 This function will be called subsequently by coroutine_handle<>::resume()
638 which calls __builtin_coro_resume(__handle_))A"sv);
639
640 auto* callCoroFSM = Call(fsmFuncDecl, {mASTData.mFrameAccessDeclRef});
641 InsertArgWithNull(callCoroFSM);
642
645
646 InsertArg(stmt->getReturnStmt());
647
649
650 mOutputFormatHelper.CloseScope(OutputFormatHelper::NoNewLineBefore::Yes);
653
654 // add contents of the original function to the body of our made up function
655 StmtsContainer fsmFuncBodyStmts{stmt};
656
657 mOutputFormatHelper.AppendCommentNewLine("This function invoked by coroutine_handle<>::resume()"sv);
658 SetFunctionBody(fsmFuncDecl, fsmFuncBodyStmts);
659 InsertArg(fsmFuncDecl);
660
661 mASTData.mDoInsertInDtor = true; // As we have a coroutine insert the frame when this object goes out of scope.
662
663#if 0 // Preserve for later. Technically the destructor for the entire frame that's made up below takes care of
664 // everything.
665
666 // A destructor is only present, if they promise_type or one of its members is non-trivially destructible.
667 if(auto* dtor = mASTData.mPromiseField->getType()->getAsCXXRecordDecl()->getDestructor()) {
668 deallocFuncBodyStmts.Add(Comment("Deallocating the coroutine promise type"sv));
669
670 auto* promiseAccess = AccessMember(mASTData.mFrameAccessDeclRef, mASTData.mPromiseField);
671 auto* deallocPromise = AccessMember(promiseAccess, dtor, false);
672 auto* dtorCall = CallMemberFun(deallocPromise, dtor->getType());
673 deallocFuncBodyStmts.Add(dtorCall);
674
675 } else {
676 deallocFuncBodyStmts.Add(
677 Comment("promise_type is trivially destructible, no dtor required."sv));
678 }
679#endif
680
681 // This code isn't really there but it is the easiest and cleanest way to visualize the destruction of all
682 // member in the frame. The deallocation function:
683 // https://devblogs.microsoft.com/oldnewthing/20210331-00/?p=105028
685 mOutputFormatHelper.AppendCommentNewLine("This function invoked by coroutine_handle<>::destroy()"sv);
686
687 StmtsContainer deallocFuncBodyStmts{Comment("destroy all variables with dtors"sv)};
688
689 auto* dtorFuncDecl =
690 Function(StrCat("~"sv, GetName(*mASTData.mFrameType)), VoidTy(), {{CORO_FRAME_NAME, GetFramePointerType()}});
691 auto* deallocPromise = AccessMember(mASTData.mFrameAccessDeclRef, dtorFuncDecl);
692 auto* dtorCall = CallMemberFun(deallocPromise, GetFrameType());
693 deallocFuncBodyStmts.Add(dtorCall);
694
695 deallocFuncBodyStmts.Add(Comment("Deallocating the coroutine frame"sv));
696 deallocFuncBodyStmts.Add(
697 Comment("Note: The actual argument to delete is __builtin_coro_frame with the promise as parameter"sv));
698
699 deallocFuncBodyStmts.Add(stmt->getDeallocate());
700
701 SetFunctionBody(deallocFuncDecl, deallocFuncBodyStmts);
702 InsertArg(deallocFuncDecl);
703}
704//-----------------------------------------------------------------------------
705
706void CoroutinesCodeGenerator::InsertArg(const CoroutineBodyStmt* stmt)
707{
708 // insert a made up switch for continuing a resume
709 SwitchStmt* sstmt = Switch(mASTData.mSuspendIndexAccess);
710
711 // insert 0 with break for consistency
712 auto* initialSuspendCase = Case(0, Break());
713 StmtsContainer switchBodyStmts{initialSuspendCase};
714
715 for(const auto& i : NumberIterator{mSuspendsCounter}) {
716 switchBodyStmts.Add(Case(i + 1, Goto(BuildResumeLabelName(i + 1))));
717 }
718
719 auto* switchBody = mkCompoundStmt(switchBodyStmts);
720 sstmt->setBody(switchBody);
721
722 StmtsContainer funcBodyStmts{
723 Comment("Create a switch to get to the correct resume point"sv), sstmt, stmt->getInitSuspendStmt()};
724
725 // insert the init suspend expr
726 mState = eState::InitialSuspend;
727
728 if(mASTData.mThisExprs.size()) {
729 AddField(kwInternalThis, mASTData.mThisExprs.at(0)->getType());
730 }
731
732 mInsertVarDecl = false;
733 mSupressRecordDecls = true;
734
735 for(const auto* c : stmt->getBody()->children()) {
736 funcBodyStmts.Add(c);
737 }
738
739 if(const auto* coReturnVoid = dyn_cast_or_null<CoreturnStmt>(stmt->getFallthroughHandler())) {
740 funcBodyStmts.Add(coReturnVoid);
741 }
742
743 auto* gotoFinalSuspend = Goto(FINAL_SUSPEND_NAME);
744 funcBodyStmts.Add(gotoFinalSuspend);
745
746 auto* body = [&]() -> Stmt* {
747 auto* tryBody = mkCompoundStmt(funcBodyStmts);
748
749 // First open the try-catch block, as we get an error when jumping across such blocks with goto
750 if(const auto* exceptionHandler = stmt->getExceptionHandler()) {
751 // If we encounter an exception before initial_suspend's await_suspend was called we re-throw the
752 // exception.
753 auto* ifStmt = If(Not(mASTData.mInitialAwaitResumeCalledAccess), Throw());
754
755 StmtsContainer catchBodyStmts{ifStmt, exceptionHandler};
756
757 return Try(tryBody, Catch(catchBodyStmts));
758 }
759
760 return tryBody;
761 }();
762
763 InsertArg(body);
764
766
767 auto* finalSuspendLabel = Label(FINAL_SUSPEND_NAME);
768 InsertArg(finalSuspendLabel);
769 mState = eState::FinalSuspend;
770 InsertArg(stmt->getFinalSuspendStmt());
771
772 // disable prefixing names and types
773 mInsertVarDecl = true;
774}
775//-----------------------------------------------------------------------------
776
777void CoroutinesCodeGenerator::InsertArg(const CXXRecordDecl* stmt)
778{
779 if(not mSupressRecordDecls) {
781 }
782}
783//-----------------------------------------------------------------------------
785// We seem to need this, to peel off some static_casts in a CoroutineSuspendExpr.
786void CoroutinesCodeGenerator::InsertArg(const ImplicitCastExpr* stmt)
787{
788 if(mSupressCasts) {
789 InsertArg(stmt->getSubExpr());
790 } else {
792 }
793}
794//-----------------------------------------------------------------------------
795
796// A special hack to avoid having calls to __builtin_coro_xxx as some of them result in a crash
797// of the compiler and make assumptions about the call order and function location.
798void CoroutinesCodeGenerator::InsertArg(const CallExpr* stmt)
799{
800 if(const auto* callee = dyn_cast_or_null<DeclRefExpr>(stmt->getCallee()->IgnoreCasts())) {
801 if(GetPlainName(*callee) == "__builtin_coro_frame"sv) {
803 return;
804
805 } else if(GetPlainName(*callee) == "__builtin_coro_free"sv) {
806 CodeGenerator::InsertArg(stmt->getArg(0));
807 return;
808
809 } else if(GetPlainName(*callee) == "__builtin_coro_size"sv) {
810 CodeGenerator::InsertArg(Sizeof(GetFrameType()));
811 return;
812 }
813 }
814
816}
817//-----------------------------------------------------------------------------
819static std::optional<std::string>
820FindValue(llvm::DenseMap<const Expr*, std::pair<const DeclRefExpr*, std::string>>& map, const Expr* key)
821{
822 if(const auto& s = map.find(key); s != map.end()) {
823 return s->second.second;
824 }
825
826 return {};
827}
828//-----------------------------------------------------------------------------
829
830void CoroutinesCodeGenerator::InsertArg(const OpaqueValueExpr* stmt)
831{
832 const auto* sourceExpr = stmt->getSourceExpr();
833
834 if(const auto& s = FindValue(mOpaqueValues, sourceExpr)) {
835 mOutputFormatHelper.Append(s.value());
836
837 } else {
838 // Needs to be internal because a user can create the same type and it gets put into the stack frame
839 std::string name{BuildSuspendVarName(stmt)};
840
841 // In case of a coroutine-template the same suspension point can occur multiple times. But to know when to add
842 // the _1 we must match the one from each instantiation. The DeclRefExpr is what distinguishes the same
843 // OpaqueValueExpr between multiple instantiations.
844 const auto* dref = FindDeclRef(sourceExpr);
845
846 // The initial_suspend and final_suspend expressions carry the same location info. If we hit such a case,
847 // make up another name.
848 // Below is a std::find_if. However, the same code looks unreadable with std::find_if
849 for(const auto lookupName{StrCat(CORO_FRAME_ACCESS, name)}; const auto& [k, value] : mOpaqueValues) {
850 if(auto [thisDeref, v] = value; (thisDeref == dref) and (v == lookupName)) {
851 name += "_1"sv;
852 break;
853 }
854 }
855
856 const auto accessName{StrCat(CORO_FRAME_ACCESS, name)};
857 mOpaqueValues.insert(std::make_pair(sourceExpr, std::make_pair(dref, accessName)));
858
859 OutputFormatHelper ofm{};
860 CoroutinesCodeGenerator codeGenerator{ofm, mPosBeforeFunc, mFSMName, mSuspendsCount, mASTData};
861
862 auto* promiseField = AddField(name, stmt->getType());
863 BinaryOperator* assignPromiseSuspend =
864 Assign(mASTData.mFrameAccessDeclRef, promiseField, stmt->getSourceExpr());
865
866 codeGenerator.InsertArg(assignPromiseSuspend);
867 ofm.AppendSemiNewLine();
868
869 ofm.SetIndent(mOutputFormatHelper);
870
871 mOutputFormatHelper.InsertAt(mPosBeforeSuspendExpr, ofm);
872 mOutputFormatHelper.Append(accessName);
873 }
874}
875//-----------------------------------------------------------------------------
876
877std::string CoroutinesCodeGenerator::BuildResumeLabelName(int index) const
878{
879 return StrCat(RESUME_LABEL_PREFIX, "_"sv, mFSMName, "_"sv, index);
880}
881//-----------------------------------------------------------------------------
882
883void CoroutinesCodeGenerator::InsertArg(const CoroutineSuspendExpr* stmt)
884{
886 InsertInstantiationPoint(GetGlobalAST().getSourceManager(), stmt->getKeywordLoc(), [&] {
887 if(isa<CoawaitExpr>(stmt)) {
888 return kwCoAwaitSpace;
889 } else {
890 return kwCoYieldSpace;
891 }
892 }());
893
894 mPosBeforeSuspendExpr = mOutputFormatHelper.CurrentPos();
895
896 /// Represents an expression that might suspend coroutine execution;
897 /// either a co_await or co_yield expression.
898 ///
899 /// Evaluation of this expression first evaluates its 'ready' expression. If
900 /// that returns 'false':
901 /// -- execution of the coroutine is suspended
902 /// -- the 'suspend' expression is evaluated
903 /// -- if the 'suspend' expression returns 'false', the coroutine is
904 /// resumed
905 /// -- otherwise, control passes back to the resumer.
906 /// If the coroutine is not suspended, or when it is resumed, the 'resume'
907 /// expression is evaluated, and its result is the result of the overall
908 /// expression.
909
910 // mOutputFormatHelper.AppendNewLine("// __builtin_coro_save() // frame->suspend_index = n");
911
912 // For why, see the implementation of CoroutinesCodeGenerator::InsertArg(const ImplicitCastExpr* stmt)
913 mSupressCasts = true;
914
915 auto* il = Int32(++mSuspendsCount);
916 auto* bop = Assign(mASTData.mSuspendIndexAccess, mASTData.mSuspendIndexField, il);
917
918 // Find out whether the return type is void or bool. In case of bool, we need to insert an if-statement, to
919 // suspend only, if the return value was true.
920 // Technically only void, bool, or std::coroutine_handle<Z> is allowed. [expr.await] p3.7
921 const bool returnsVoid{stmt->getSuspendExpr()->getType()->isVoidType()};
922
923 // XXX: check if getResumeExpr is marked noexcept. Otherwise we need additional expcetion handling?
924 // CGCoroutine.cpp:229
925
926 StmtsContainer bodyStmts{};
927 Expr* initializeInitialAwaitResume = nullptr;
928
929 auto addInitialAwaitSuspendCalled = [&] {
930 bodyStmts.Add(bop);
931
932 if(eState::InitialSuspend == mState) {
933 mState = eState::Body;
934 // https://timsong-cpp.github.io/cppwp/n4861/dcl.fct.def.coroutine#5.3
935 initializeInitialAwaitResume =
936 Assign(mASTData.mFrameAccessDeclRef, mASTData.mInitialAwaitResumeCalledField, Bool(true));
937 bodyStmts.Add(initializeInitialAwaitResume);
938 }
939 };
940
941 if(returnsVoid) {
942 bodyStmts.Add(stmt->getSuspendExpr());
943 addInitialAwaitSuspendCalled();
944 bodyStmts.Add(Return());
945
946 InsertArg(If(Not(stmt->getReadyExpr()), bodyStmts));
947
948 } else {
949 addInitialAwaitSuspendCalled();
950 bodyStmts.Add(Return());
951
952 auto* ifSuspend = If(stmt->getSuspendExpr(), bodyStmts);
953
954 InsertArg(If(Not(stmt->getReadyExpr()), ifSuspend));
955 }
956
957 if(not returnsVoid and initializeInitialAwaitResume) {
958 // At this point we technically haven't called initial suspend
959 InsertArgWithNull(initializeInitialAwaitResume);
960 mOutputFormatHelper.AppendNewLine();
961 }
962
963 auto* suspendLabel = Label(BuildResumeLabelName(mSuspendsCount));
964 InsertArg(suspendLabel);
965
966 if(eState::FinalSuspend == mState) {
967 auto* memExpr = AccessMember(mASTData.mFrameAccessDeclRef, mASTData.mDestroyFnField, true);
968 auto* callCoroFSM = Call(memExpr, {mASTData.mFrameAccessDeclRef});
969 InsertArg(callCoroFSM);
970 return;
971 }
972
973 const auto* resumeExpr = stmt->getResumeExpr();
974
975 if(not resumeExpr->getType()->isVoidType()) {
976 const auto* sourceExpr = stmt->getOpaqueValue()->getSourceExpr();
977
978 if(const auto& s = FindValue(mOpaqueValues, sourceExpr)) {
979 const auto fieldName{StrCat(std::string_view{s.value()}.substr(CORO_FRAME_ACCESS.size()), "_res"sv)};
980 mOutputFormatHelper.Append(CORO_FRAME_ACCESS, fieldName, hlpAssing);
981
982 AddField(fieldName, resumeExpr->getType());
983 }
984 }
985
986 InsertArg(resumeExpr);
987}
988//-----------------------------------------------------------------------------
989
990void CoroutinesCodeGenerator::InsertArg(const CoreturnStmt* stmt)
991{
992 InsertInstantiationPoint(GetGlobalAST().getSourceManager(), stmt->getKeywordLoc(), kwCoReturnSpace);
993
994 if(stmt->getPromiseCall()) {
995 InsertArg(stmt->getPromiseCall());
996
997 if(stmt->isImplicit()) {
998 mOutputFormatHelper.AppendComment("implicit"sv);
999 }
1000 }
1001}
1002//-----------------------------------------------------------------------------
1003
1004void CoroutinesCodeGenerator::InsertArgWithNull(const Stmt* stmt)
1005{
1006 InsertArg(stmt);
1007 InsertArg(mkNullStmt());
1008}
1009//-----------------------------------------------------------------------------
1010
1011} // namespace clang::insights
const ASTContext & GetGlobalAST()
Get access to the ASTContext.
Definition Insights.cpp:81
constexpr std::string_view hlpDestroyFn
constexpr std::string_view kwCoReturnSpace
constexpr std::string_view kwInternalThis
constexpr std::string_view hlpResumeFn
constexpr std::string_view hlpAssing
#define RETURN_IF(cond)
A helper inspired by https://github.com/Microsoft/wil/wiki/Error-handling-helpers
virtual void InsertArg(const Decl *stmt)
void InsertInstantiationPoint(const SourceManager &sm, const SourceLocation &instLoc, std::string_view text={})
Inserts the instantiation point of a template.
void InsertTemplateArg(const TemplateArgument &arg)
OutputFormatHelper & mOutputFormatHelper
A special container which creates either a CodeGenerator or a CfrontCodeGenerator depending on the co...
Finds suspend expressions in a coroutine body statement for early transformation.
CoroutineASTTransformer(CoroutineASTData &coroutineASTData, size_t &suspendsCounter, Stmt *stmt, llvm::DenseMap< VarDecl *, MemberExpr * > varNamePrefix, Stmt *prev=nullptr)
A special generator for coroutines. It is only activated, if -show-coroutines-transformation is given...
void InsertArg(const ImplicitCastExpr *stmt) override
void InsertCoroutine(const FunctionDecl &fd, const CoroutineBodyStmt *body)
void OpenScope()
Open a scope by inserting a '{' followed by an indented newline.
void AppendNewLine(const char c)
Same as Append but adds a newline after the last argument.
void Append(const char c)
Append a single character.
void CloseScope(const NoNewLineBefore newLineBefore=NoNewLineBefore::No)
Close a scope by inserting a '}'.
void AppendCommentNewLine(const std::string_view &arg)
void AppendSemiNewLine()
Append a semicolon and a newline.
void InsertAt(const size_t atPos, std::string_view data)
Insert a string before the position atPos.
BinaryOperator * Equal(Expr *var, Expr *assignExpr)
UnaryOperator * AddrOf(const Expr *stmt)
CXXReinterpretCastExpr * ReinterpretCast(QualType toType, const Expr *toExpr, bool makePointer)
CXXRecordDecl * Struct(std::string_view name)
FieldDecl * mkFieldDecl(DeclContext *dc, std::string_view name, QualType type)
DeclRefExpr * mkDeclRefExpr(const ValueDecl *vd)
CXXNewExpr * New(ArrayRef< Expr * > placementArgs, const Expr *expr, QualType t)
std::vector< std::pair< std::string_view, QualType > > params_vector
Definition ASTHelpers.h:27
CallExpr * Call(const FunctionDecl *fd, ArrayRef< Expr * > params)
UnaryOperator * Ref(const Expr *e)
VarDecl * Variable(std::string_view name, QualType type, DeclContext *dc)
MemberExpr * AccessMember(const Expr *expr, const ValueDecl *vd, bool isArrow)
DeclRefExpr * mkVarDeclRefExpr(std::string_view name, QualType type)
ReturnStmt * Return(Expr *stmt)
CaseStmt * Case(int value, Stmt *stmt)
Stmt * Comment(std::string_view comment)
IfStmt * If(const Expr *condition, ArrayRef< Stmt * > bodyStmts)
GotoStmt * Goto(std::string_view labelName)
CXXCatchStmt * Catch(ArrayRef< Stmt * > body)
CXXThrowExpr * Throw(const Expr *expr)
void ReplaceNode(Stmt *parent, Stmt *oldNode, Stmt *newNode)
UnaryExprOrTypeTraitExpr * Sizeof(QualType toType)
LabelStmt * Label(std::string_view name)
UnaryOperator * Not(const Expr *stmt)
CXXBoolLiteralExpr * Bool(bool b)
FunctionDecl * Function(std::string_view name, QualType returnType, const params_vector &parameters)
CompoundStmt * mkCompoundStmt(ArrayRef< Stmt * > bodyStmts, SourceLocation beginLoc, SourceLocation endLoc)
SwitchStmt * Switch(Expr *stmt)
IntegerLiteral * Int32(uint64_t value)
ParmVarDecl * Parameter(const FunctionDecl *fd, std::string_view name, QualType type)
UnaryOperator * Dref(const Expr *stmt)
QualType Ptr(QualType srcType)
CXXStaticCastExpr * StaticCast(QualType toType, const Expr *toExpr, bool makePointer)
CXXMemberCallExpr * CallMemberFun(Expr *memExpr, QualType retType)
BinaryOperator * Assign(const VarDecl *var, Expr *assignExpr)
CXXTryStmt * Try(const Stmt *tryBody, CXXCatchStmt *catchAllBody)
const std::string SUSPEND_INDEX_NAME
const std::string RESUME_LABEL_PREFIX
bool Contains(const std::string_view source, const std::string_view search)
const DeclRefExpr * FindDeclRef(const Stmt *stmt)
Goes deep into a Stmt if necessary and searches all children for a DeclRefExpr.
std::string GetPlainName(const DeclRefExpr &DRE)
static auto * CreateCoroFunctionDecl(std::string funcName, QualType type)
static void SetFunctionBody(FunctionDecl *fd, StmtsContainer &bodyStmts)
constexpr std::string_view CORO_FRAME_NAME
std::string MakeLineColumnName(const SourceManager &sm, const SourceLocation &loc, const std::string_view &prefix)
std::string GetName(const NamedDecl &nd, const QualifiedName qualifiedName)
void ReplaceAll(std::string &str, std::string_view from, std::string_view to)
std::string BuildInternalVarName(const std::string_view &varName)
const std::string CORO_FRAME_ACCESS_THIS
std::string BuildTemplateParamObjectName(std::string name)
const std::string FINAL_SUSPEND_NAME
const std::string INITIAL_AWAIT_SUSPEND_CALLED_NAME
static std::string BuildSuspendVarName(const OpaqueValueExpr *stmt)
static std::optional< std::string > FindValue(llvm::DenseMap< const Expr *, std::pair< const DeclRefExpr *, std::string > > &map, const Expr *key)
void EnableGlobalInsert(GlobalInserts idx)
Definition Insights.cpp:106
const std::string CORO_FRAME_ACCESS
static FieldDecl * AddField(CoroutineASTData &astData, std::string_view name, QualType type)
std::string StrCat(const auto &... args)
std::vector< const CXXThisExpr * > mThisExprs
A helper type to have a container for ArrayRef
Definition ASTHelpers.h:64