Merge pull request #21410 from Shourya742/2026-01-06-improve-bidirectional-cancellation

Make proc-macro bidirectional calls cancellation safe
This commit is contained in:
Lukas Wirth 2026-02-01 12:06:03 +00:00 committed by GitHub
commit f4292527f9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 114 additions and 36 deletions

View file

@ -2,6 +2,7 @@
use std::{
io::{self, BufRead, Write},
panic::{AssertUnwindSafe, catch_unwind},
sync::Arc,
};
@ -55,9 +56,19 @@ pub fn run_conversation(
return Ok(BidirectionalMessage::Response(response));
}
BidirectionalMessage::SubRequest(sr) => {
let resp = callback(sr)?;
let reply = BidirectionalMessage::SubResponse(resp);
let encoded = postcard::encode(&reply).map_err(wrap_encode)?;
// TODO: Avoid `AssertUnwindSafe` by making the callback `UnwindSafe` once `ExpandDatabase`
// becomes unwind-safe (currently blocked by `parking_lot::RwLock` in the VFS).
let resp = match catch_unwind(AssertUnwindSafe(|| callback(sr))) {
Ok(Ok(resp)) => BidirectionalMessage::SubResponse(resp),
Ok(Err(err)) => BidirectionalMessage::SubResponse(SubResponse::Cancel {
reason: err.to_string(),
}),
Err(_) => BidirectionalMessage::SubResponse(SubResponse::Cancel {
reason: "callback panicked or was cancelled".into(),
}),
};
let encoded = postcard::encode(&resp).map_err(wrap_encode)?;
postcard::write(writer, &encoded)
.map_err(wrap_io("failed to write sub-response"))?;
}

View file

@ -42,6 +42,9 @@ pub enum SubResponse {
ByteRangeResult {
range: Range<usize>,
},
Cancel {
reason: String,
},
}
#[derive(Debug, Serialize, Deserialize)]

View file

@ -3,6 +3,7 @@ use proc_macro_api::{
ProtocolFormat, bidirectional_protocol::msg as bidirectional, legacy_protocol::msg as legacy,
version::CURRENT_API_VERSION,
};
use std::panic::{panic_any, resume_unwind};
use std::{
io::{self, BufRead, Write},
ops::Range,
@ -10,7 +11,7 @@ use std::{
use legacy::Message;
use proc_macro_srv::{EnvSnapshot, SpanId};
use proc_macro_srv::{EnvSnapshot, ProcMacroClientError, ProcMacroPanicMarker, SpanId};
struct SpanTrans;
@ -172,16 +173,43 @@ impl<'a> ProcMacroClientHandle<'a> {
fn roundtrip(
&mut self,
req: bidirectional::SubRequest,
) -> Option<bidirectional::BidirectionalMessage> {
) -> Result<bidirectional::SubResponse, ProcMacroClientError> {
let msg = bidirectional::BidirectionalMessage::SubRequest(req);
if msg.write(&mut *self.stdout).is_err() {
return None;
}
msg.write(&mut *self.stdout).map_err(ProcMacroClientError::Io)?;
match bidirectional::BidirectionalMessage::read(&mut *self.stdin, self.buf) {
Ok(Some(msg)) => Some(msg),
_ => None,
let msg = bidirectional::BidirectionalMessage::read(&mut *self.stdin, self.buf)
.map_err(ProcMacroClientError::Io)?
.ok_or(ProcMacroClientError::Eof)?;
match msg {
bidirectional::BidirectionalMessage::SubResponse(resp) => match resp {
bidirectional::SubResponse::Cancel { reason } => {
Err(ProcMacroClientError::Cancelled { reason })
}
other => Ok(other),
},
other => {
Err(ProcMacroClientError::Protocol(format!("expected SubResponse, got {other:?}")))
}
}
}
}
fn handle_failure(failure: Result<bidirectional::SubResponse, ProcMacroClientError>) -> ! {
match failure {
Err(ProcMacroClientError::Cancelled { reason }) => {
resume_unwind(Box::new(ProcMacroPanicMarker::Cancelled { reason }));
}
Err(err) => {
panic_any(ProcMacroPanicMarker::Internal {
reason: format!("proc-macro IPC error: {err:?}"),
});
}
Ok(other) => {
panic_any(ProcMacroPanicMarker::Internal {
reason: format!("unexpected SubResponse {other:?}"),
});
}
}
}
@ -189,10 +217,8 @@ impl<'a> ProcMacroClientHandle<'a> {
impl proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandle<'_> {
fn file(&mut self, file_id: proc_macro_srv::span::FileId) -> String {
match self.roundtrip(bidirectional::SubRequest::FilePath { file_id: file_id.index() }) {
Some(bidirectional::BidirectionalMessage::SubResponse(
bidirectional::SubResponse::FilePathResult { name },
)) => name,
_ => String::new(),
Ok(bidirectional::SubResponse::FilePathResult { name }) => name,
other => handle_failure(other),
}
}
@ -206,20 +232,16 @@ impl proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandle<'_> {
start: range.start().into(),
end: range.end().into(),
}) {
Some(bidirectional::BidirectionalMessage::SubResponse(
bidirectional::SubResponse::SourceTextResult { text },
)) => text,
_ => None,
Ok(bidirectional::SubResponse::SourceTextResult { text }) => text,
other => handle_failure(other),
}
}
fn local_file(&mut self, file_id: proc_macro_srv::span::FileId) -> Option<String> {
match self.roundtrip(bidirectional::SubRequest::LocalFilePath { file_id: file_id.index() })
{
Some(bidirectional::BidirectionalMessage::SubResponse(
bidirectional::SubResponse::LocalFilePathResult { name },
)) => name,
_ => None,
Ok(bidirectional::SubResponse::LocalFilePathResult { name }) => name,
other => handle_failure(other),
}
}
@ -230,10 +252,10 @@ impl proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandle<'_> {
ast_id: anchor.ast_id.into_raw(),
offset: range.start().into(),
}) {
Some(bidirectional::BidirectionalMessage::SubResponse(
bidirectional::SubResponse::LineColumnResult { line, column },
)) => Some((line, column)),
_ => None,
Ok(bidirectional::SubResponse::LineColumnResult { line, column }) => {
Some((line, column))
}
other => handle_failure(other),
}
}
@ -247,10 +269,8 @@ impl proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandle<'_> {
start: range.start().into(),
end: range.end().into(),
}) {
Some(bidirectional::BidirectionalMessage::SubResponse(
bidirectional::SubResponse::ByteRangeResult { range },
)) => range,
_ => Range { start: range.start().into(), end: range.end().into() },
Ok(bidirectional::SubResponse::ByteRangeResult { range }) => range,
other => handle_failure(other),
}
}
}

View file

@ -96,6 +96,20 @@ impl<'env> ProcMacroSrv<'env> {
}
}
#[derive(Debug)]
pub enum ProcMacroClientError {
Cancelled { reason: String },
Io(std::io::Error),
Protocol(String),
Eof,
}
#[derive(Debug)]
pub enum ProcMacroPanicMarker {
Cancelled { reason: String },
Internal { reason: String },
}
pub type ProcMacroClientHandle<'a> = &'a mut (dyn ProcMacroClientInterface + Sync + Send);
pub trait ProcMacroClientInterface {
@ -110,6 +124,22 @@ pub trait ProcMacroClientInterface {
const EXPANDER_STACK_SIZE: usize = 8 * 1024 * 1024;
pub enum ExpandError {
Panic(PanicMessage),
Cancelled { reason: Option<String> },
Internal { reason: Option<String> },
}
impl ExpandError {
pub fn into_string(self) -> Option<String> {
match self {
ExpandError::Panic(panic_message) => panic_message.into_string(),
ExpandError::Cancelled { reason } => reason,
ExpandError::Internal { reason } => reason,
}
}
}
impl ProcMacroSrv<'_> {
pub fn expand<S: ProcMacroSrvSpan>(
&self,
@ -123,10 +153,10 @@ impl ProcMacroSrv<'_> {
call_site: S,
mixed_site: S,
callback: Option<ProcMacroClientHandle<'_>>,
) -> Result<token_stream::TokenStream<S>, PanicMessage> {
) -> Result<token_stream::TokenStream<S>, ExpandError> {
let snapped_env = self.env;
let expander = self.expander(lib.as_ref()).map_err(|err| PanicMessage {
message: Some(format!("failed to load macro: {err}")),
let expander = self.expander(lib.as_ref()).map_err(|err| ExpandError::Internal {
reason: Some(format!("failed to load macro: {err}")),
})?;
let prev_env = EnvChange::apply(snapped_env, env, current_dir.as_ref().map(<_>::as_ref));
@ -144,8 +174,22 @@ impl ProcMacroSrv<'_> {
)
});
match thread.unwrap().join() {
Ok(res) => res,
Err(e) => std::panic::resume_unwind(e),
Ok(res) => res.map_err(ExpandError::Panic),
Err(payload) => {
if let Some(marker) = payload.downcast_ref::<ProcMacroPanicMarker>() {
return match marker {
ProcMacroPanicMarker::Cancelled { reason } => {
Err(ExpandError::Cancelled { reason: Some(reason.clone()) })
}
ProcMacroPanicMarker::Internal { reason } => {
Err(ExpandError::Internal { reason: Some(reason.clone()) })
}
};
}
std::panic::resume_unwind(payload)
}
}
});
prev_env.rollback();