Merge pull request #21410 from Shourya742/2026-01-06-improve-bidirectional-cancellation
Make proc-macro bidirectional calls cancellation safe
This commit is contained in:
commit
f4292527f9
4 changed files with 114 additions and 36 deletions
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
use std::{
|
||||
io::{self, BufRead, Write},
|
||||
panic::{AssertUnwindSafe, catch_unwind},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
|
|
@ -55,9 +56,19 @@ pub fn run_conversation(
|
|||
return Ok(BidirectionalMessage::Response(response));
|
||||
}
|
||||
BidirectionalMessage::SubRequest(sr) => {
|
||||
let resp = callback(sr)?;
|
||||
let reply = BidirectionalMessage::SubResponse(resp);
|
||||
let encoded = postcard::encode(&reply).map_err(wrap_encode)?;
|
||||
// TODO: Avoid `AssertUnwindSafe` by making the callback `UnwindSafe` once `ExpandDatabase`
|
||||
// becomes unwind-safe (currently blocked by `parking_lot::RwLock` in the VFS).
|
||||
let resp = match catch_unwind(AssertUnwindSafe(|| callback(sr))) {
|
||||
Ok(Ok(resp)) => BidirectionalMessage::SubResponse(resp),
|
||||
Ok(Err(err)) => BidirectionalMessage::SubResponse(SubResponse::Cancel {
|
||||
reason: err.to_string(),
|
||||
}),
|
||||
Err(_) => BidirectionalMessage::SubResponse(SubResponse::Cancel {
|
||||
reason: "callback panicked or was cancelled".into(),
|
||||
}),
|
||||
};
|
||||
|
||||
let encoded = postcard::encode(&resp).map_err(wrap_encode)?;
|
||||
postcard::write(writer, &encoded)
|
||||
.map_err(wrap_io("failed to write sub-response"))?;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,6 +42,9 @@ pub enum SubResponse {
|
|||
ByteRangeResult {
|
||||
range: Range<usize>,
|
||||
},
|
||||
Cancel {
|
||||
reason: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ use proc_macro_api::{
|
|||
ProtocolFormat, bidirectional_protocol::msg as bidirectional, legacy_protocol::msg as legacy,
|
||||
version::CURRENT_API_VERSION,
|
||||
};
|
||||
use std::panic::{panic_any, resume_unwind};
|
||||
use std::{
|
||||
io::{self, BufRead, Write},
|
||||
ops::Range,
|
||||
|
|
@ -10,7 +11,7 @@ use std::{
|
|||
|
||||
use legacy::Message;
|
||||
|
||||
use proc_macro_srv::{EnvSnapshot, SpanId};
|
||||
use proc_macro_srv::{EnvSnapshot, ProcMacroClientError, ProcMacroPanicMarker, SpanId};
|
||||
|
||||
struct SpanTrans;
|
||||
|
||||
|
|
@ -172,16 +173,43 @@ impl<'a> ProcMacroClientHandle<'a> {
|
|||
fn roundtrip(
|
||||
&mut self,
|
||||
req: bidirectional::SubRequest,
|
||||
) -> Option<bidirectional::BidirectionalMessage> {
|
||||
) -> Result<bidirectional::SubResponse, ProcMacroClientError> {
|
||||
let msg = bidirectional::BidirectionalMessage::SubRequest(req);
|
||||
|
||||
if msg.write(&mut *self.stdout).is_err() {
|
||||
return None;
|
||||
}
|
||||
msg.write(&mut *self.stdout).map_err(ProcMacroClientError::Io)?;
|
||||
|
||||
match bidirectional::BidirectionalMessage::read(&mut *self.stdin, self.buf) {
|
||||
Ok(Some(msg)) => Some(msg),
|
||||
_ => None,
|
||||
let msg = bidirectional::BidirectionalMessage::read(&mut *self.stdin, self.buf)
|
||||
.map_err(ProcMacroClientError::Io)?
|
||||
.ok_or(ProcMacroClientError::Eof)?;
|
||||
|
||||
match msg {
|
||||
bidirectional::BidirectionalMessage::SubResponse(resp) => match resp {
|
||||
bidirectional::SubResponse::Cancel { reason } => {
|
||||
Err(ProcMacroClientError::Cancelled { reason })
|
||||
}
|
||||
other => Ok(other),
|
||||
},
|
||||
other => {
|
||||
Err(ProcMacroClientError::Protocol(format!("expected SubResponse, got {other:?}")))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_failure(failure: Result<bidirectional::SubResponse, ProcMacroClientError>) -> ! {
|
||||
match failure {
|
||||
Err(ProcMacroClientError::Cancelled { reason }) => {
|
||||
resume_unwind(Box::new(ProcMacroPanicMarker::Cancelled { reason }));
|
||||
}
|
||||
Err(err) => {
|
||||
panic_any(ProcMacroPanicMarker::Internal {
|
||||
reason: format!("proc-macro IPC error: {err:?}"),
|
||||
});
|
||||
}
|
||||
Ok(other) => {
|
||||
panic_any(ProcMacroPanicMarker::Internal {
|
||||
reason: format!("unexpected SubResponse {other:?}"),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -189,10 +217,8 @@ impl<'a> ProcMacroClientHandle<'a> {
|
|||
impl proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandle<'_> {
|
||||
fn file(&mut self, file_id: proc_macro_srv::span::FileId) -> String {
|
||||
match self.roundtrip(bidirectional::SubRequest::FilePath { file_id: file_id.index() }) {
|
||||
Some(bidirectional::BidirectionalMessage::SubResponse(
|
||||
bidirectional::SubResponse::FilePathResult { name },
|
||||
)) => name,
|
||||
_ => String::new(),
|
||||
Ok(bidirectional::SubResponse::FilePathResult { name }) => name,
|
||||
other => handle_failure(other),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -206,20 +232,16 @@ impl proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandle<'_> {
|
|||
start: range.start().into(),
|
||||
end: range.end().into(),
|
||||
}) {
|
||||
Some(bidirectional::BidirectionalMessage::SubResponse(
|
||||
bidirectional::SubResponse::SourceTextResult { text },
|
||||
)) => text,
|
||||
_ => None,
|
||||
Ok(bidirectional::SubResponse::SourceTextResult { text }) => text,
|
||||
other => handle_failure(other),
|
||||
}
|
||||
}
|
||||
|
||||
fn local_file(&mut self, file_id: proc_macro_srv::span::FileId) -> Option<String> {
|
||||
match self.roundtrip(bidirectional::SubRequest::LocalFilePath { file_id: file_id.index() })
|
||||
{
|
||||
Some(bidirectional::BidirectionalMessage::SubResponse(
|
||||
bidirectional::SubResponse::LocalFilePathResult { name },
|
||||
)) => name,
|
||||
_ => None,
|
||||
Ok(bidirectional::SubResponse::LocalFilePathResult { name }) => name,
|
||||
other => handle_failure(other),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -230,10 +252,10 @@ impl proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandle<'_> {
|
|||
ast_id: anchor.ast_id.into_raw(),
|
||||
offset: range.start().into(),
|
||||
}) {
|
||||
Some(bidirectional::BidirectionalMessage::SubResponse(
|
||||
bidirectional::SubResponse::LineColumnResult { line, column },
|
||||
)) => Some((line, column)),
|
||||
_ => None,
|
||||
Ok(bidirectional::SubResponse::LineColumnResult { line, column }) => {
|
||||
Some((line, column))
|
||||
}
|
||||
other => handle_failure(other),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -247,10 +269,8 @@ impl proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandle<'_> {
|
|||
start: range.start().into(),
|
||||
end: range.end().into(),
|
||||
}) {
|
||||
Some(bidirectional::BidirectionalMessage::SubResponse(
|
||||
bidirectional::SubResponse::ByteRangeResult { range },
|
||||
)) => range,
|
||||
_ => Range { start: range.start().into(), end: range.end().into() },
|
||||
Ok(bidirectional::SubResponse::ByteRangeResult { range }) => range,
|
||||
other => handle_failure(other),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -96,6 +96,20 @@ impl<'env> ProcMacroSrv<'env> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ProcMacroClientError {
|
||||
Cancelled { reason: String },
|
||||
Io(std::io::Error),
|
||||
Protocol(String),
|
||||
Eof,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ProcMacroPanicMarker {
|
||||
Cancelled { reason: String },
|
||||
Internal { reason: String },
|
||||
}
|
||||
|
||||
pub type ProcMacroClientHandle<'a> = &'a mut (dyn ProcMacroClientInterface + Sync + Send);
|
||||
|
||||
pub trait ProcMacroClientInterface {
|
||||
|
|
@ -110,6 +124,22 @@ pub trait ProcMacroClientInterface {
|
|||
|
||||
const EXPANDER_STACK_SIZE: usize = 8 * 1024 * 1024;
|
||||
|
||||
pub enum ExpandError {
|
||||
Panic(PanicMessage),
|
||||
Cancelled { reason: Option<String> },
|
||||
Internal { reason: Option<String> },
|
||||
}
|
||||
|
||||
impl ExpandError {
|
||||
pub fn into_string(self) -> Option<String> {
|
||||
match self {
|
||||
ExpandError::Panic(panic_message) => panic_message.into_string(),
|
||||
ExpandError::Cancelled { reason } => reason,
|
||||
ExpandError::Internal { reason } => reason,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProcMacroSrv<'_> {
|
||||
pub fn expand<S: ProcMacroSrvSpan>(
|
||||
&self,
|
||||
|
|
@ -123,10 +153,10 @@ impl ProcMacroSrv<'_> {
|
|||
call_site: S,
|
||||
mixed_site: S,
|
||||
callback: Option<ProcMacroClientHandle<'_>>,
|
||||
) -> Result<token_stream::TokenStream<S>, PanicMessage> {
|
||||
) -> Result<token_stream::TokenStream<S>, ExpandError> {
|
||||
let snapped_env = self.env;
|
||||
let expander = self.expander(lib.as_ref()).map_err(|err| PanicMessage {
|
||||
message: Some(format!("failed to load macro: {err}")),
|
||||
let expander = self.expander(lib.as_ref()).map_err(|err| ExpandError::Internal {
|
||||
reason: Some(format!("failed to load macro: {err}")),
|
||||
})?;
|
||||
|
||||
let prev_env = EnvChange::apply(snapped_env, env, current_dir.as_ref().map(<_>::as_ref));
|
||||
|
|
@ -144,8 +174,22 @@ impl ProcMacroSrv<'_> {
|
|||
)
|
||||
});
|
||||
match thread.unwrap().join() {
|
||||
Ok(res) => res,
|
||||
Err(e) => std::panic::resume_unwind(e),
|
||||
Ok(res) => res.map_err(ExpandError::Panic),
|
||||
|
||||
Err(payload) => {
|
||||
if let Some(marker) = payload.downcast_ref::<ProcMacroPanicMarker>() {
|
||||
return match marker {
|
||||
ProcMacroPanicMarker::Cancelled { reason } => {
|
||||
Err(ExpandError::Cancelled { reason: Some(reason.clone()) })
|
||||
}
|
||||
ProcMacroPanicMarker::Internal { reason } => {
|
||||
Err(ExpandError::Internal { reason: Some(reason.clone()) })
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
std::panic::resume_unwind(payload)
|
||||
}
|
||||
}
|
||||
});
|
||||
prev_env.rollback();
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue