diff --git a/src/tools/rust-analyzer/crates/proc-macro-api/src/bidirectional_protocol.rs b/src/tools/rust-analyzer/crates/proc-macro-api/src/bidirectional_protocol.rs index 8311df23d718..ba59cb219b9a 100644 --- a/src/tools/rust-analyzer/crates/proc-macro-api/src/bidirectional_protocol.rs +++ b/src/tools/rust-analyzer/crates/proc-macro-api/src/bidirectional_protocol.rs @@ -2,6 +2,7 @@ use std::{ io::{self, BufRead, Write}, + panic::{AssertUnwindSafe, catch_unwind}, sync::Arc, }; @@ -55,9 +56,19 @@ pub fn run_conversation( return Ok(BidirectionalMessage::Response(response)); } BidirectionalMessage::SubRequest(sr) => { - let resp = callback(sr)?; - let reply = BidirectionalMessage::SubResponse(resp); - let encoded = postcard::encode(&reply).map_err(wrap_encode)?; + // TODO: Avoid `AssertUnwindSafe` by making the callback `UnwindSafe` once `ExpandDatabase` + // becomes unwind-safe (currently blocked by `parking_lot::RwLock` in the VFS). + let resp = match catch_unwind(AssertUnwindSafe(|| callback(sr))) { + Ok(Ok(resp)) => BidirectionalMessage::SubResponse(resp), + Ok(Err(err)) => BidirectionalMessage::SubResponse(SubResponse::Cancel { + reason: err.to_string(), + }), + Err(_) => BidirectionalMessage::SubResponse(SubResponse::Cancel { + reason: "callback panicked or was cancelled".into(), + }), + }; + + let encoded = postcard::encode(&resp).map_err(wrap_encode)?; postcard::write(writer, &encoded) .map_err(wrap_io("failed to write sub-response"))?; } diff --git a/src/tools/rust-analyzer/crates/proc-macro-api/src/bidirectional_protocol/msg.rs b/src/tools/rust-analyzer/crates/proc-macro-api/src/bidirectional_protocol/msg.rs index 1df0c68379a5..3f0422dc5bc8 100644 --- a/src/tools/rust-analyzer/crates/proc-macro-api/src/bidirectional_protocol/msg.rs +++ b/src/tools/rust-analyzer/crates/proc-macro-api/src/bidirectional_protocol/msg.rs @@ -42,6 +42,9 @@ pub enum SubResponse { ByteRangeResult { range: Range, }, + Cancel { + reason: String, + }, } #[derive(Debug, Serialize, Deserialize)] diff --git a/src/tools/rust-analyzer/crates/proc-macro-srv-cli/src/main_loop.rs b/src/tools/rust-analyzer/crates/proc-macro-srv-cli/src/main_loop.rs index 758629fd1fd6..9be3199a3836 100644 --- a/src/tools/rust-analyzer/crates/proc-macro-srv-cli/src/main_loop.rs +++ b/src/tools/rust-analyzer/crates/proc-macro-srv-cli/src/main_loop.rs @@ -3,6 +3,7 @@ use proc_macro_api::{ ProtocolFormat, bidirectional_protocol::msg as bidirectional, legacy_protocol::msg as legacy, version::CURRENT_API_VERSION, }; +use std::panic::{panic_any, resume_unwind}; use std::{ io::{self, BufRead, Write}, ops::Range, @@ -10,7 +11,7 @@ use std::{ use legacy::Message; -use proc_macro_srv::{EnvSnapshot, SpanId}; +use proc_macro_srv::{EnvSnapshot, ProcMacroClientError, ProcMacroPanicMarker, SpanId}; struct SpanTrans; @@ -172,16 +173,43 @@ impl<'a> ProcMacroClientHandle<'a> { fn roundtrip( &mut self, req: bidirectional::SubRequest, - ) -> Option { + ) -> Result { let msg = bidirectional::BidirectionalMessage::SubRequest(req); - if msg.write(&mut *self.stdout).is_err() { - return None; - } + msg.write(&mut *self.stdout).map_err(ProcMacroClientError::Io)?; - match bidirectional::BidirectionalMessage::read(&mut *self.stdin, self.buf) { - Ok(Some(msg)) => Some(msg), - _ => None, + let msg = bidirectional::BidirectionalMessage::read(&mut *self.stdin, self.buf) + .map_err(ProcMacroClientError::Io)? + .ok_or(ProcMacroClientError::Eof)?; + + match msg { + bidirectional::BidirectionalMessage::SubResponse(resp) => match resp { + bidirectional::SubResponse::Cancel { reason } => { + Err(ProcMacroClientError::Cancelled { reason }) + } + other => Ok(other), + }, + other => { + Err(ProcMacroClientError::Protocol(format!("expected SubResponse, got {other:?}"))) + } + } + } +} + +fn handle_failure(failure: Result) -> ! { + match failure { + Err(ProcMacroClientError::Cancelled { reason }) => { + resume_unwind(Box::new(ProcMacroPanicMarker::Cancelled { reason })); + } + Err(err) => { + panic_any(ProcMacroPanicMarker::Internal { + reason: format!("proc-macro IPC error: {err:?}"), + }); + } + Ok(other) => { + panic_any(ProcMacroPanicMarker::Internal { + reason: format!("unexpected SubResponse {other:?}"), + }); } } } @@ -189,10 +217,8 @@ impl<'a> ProcMacroClientHandle<'a> { impl proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandle<'_> { fn file(&mut self, file_id: proc_macro_srv::span::FileId) -> String { match self.roundtrip(bidirectional::SubRequest::FilePath { file_id: file_id.index() }) { - Some(bidirectional::BidirectionalMessage::SubResponse( - bidirectional::SubResponse::FilePathResult { name }, - )) => name, - _ => String::new(), + Ok(bidirectional::SubResponse::FilePathResult { name }) => name, + other => handle_failure(other), } } @@ -206,20 +232,16 @@ impl proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandle<'_> { start: range.start().into(), end: range.end().into(), }) { - Some(bidirectional::BidirectionalMessage::SubResponse( - bidirectional::SubResponse::SourceTextResult { text }, - )) => text, - _ => None, + Ok(bidirectional::SubResponse::SourceTextResult { text }) => text, + other => handle_failure(other), } } fn local_file(&mut self, file_id: proc_macro_srv::span::FileId) -> Option { match self.roundtrip(bidirectional::SubRequest::LocalFilePath { file_id: file_id.index() }) { - Some(bidirectional::BidirectionalMessage::SubResponse( - bidirectional::SubResponse::LocalFilePathResult { name }, - )) => name, - _ => None, + Ok(bidirectional::SubResponse::LocalFilePathResult { name }) => name, + other => handle_failure(other), } } @@ -230,10 +252,10 @@ impl proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandle<'_> { ast_id: anchor.ast_id.into_raw(), offset: range.start().into(), }) { - Some(bidirectional::BidirectionalMessage::SubResponse( - bidirectional::SubResponse::LineColumnResult { line, column }, - )) => Some((line, column)), - _ => None, + Ok(bidirectional::SubResponse::LineColumnResult { line, column }) => { + Some((line, column)) + } + other => handle_failure(other), } } @@ -247,10 +269,8 @@ impl proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandle<'_> { start: range.start().into(), end: range.end().into(), }) { - Some(bidirectional::BidirectionalMessage::SubResponse( - bidirectional::SubResponse::ByteRangeResult { range }, - )) => range, - _ => Range { start: range.start().into(), end: range.end().into() }, + Ok(bidirectional::SubResponse::ByteRangeResult { range }) => range, + other => handle_failure(other), } } } diff --git a/src/tools/rust-analyzer/crates/proc-macro-srv/src/lib.rs b/src/tools/rust-analyzer/crates/proc-macro-srv/src/lib.rs index e04f744ae2b0..c548dc620ad1 100644 --- a/src/tools/rust-analyzer/crates/proc-macro-srv/src/lib.rs +++ b/src/tools/rust-analyzer/crates/proc-macro-srv/src/lib.rs @@ -96,6 +96,20 @@ impl<'env> ProcMacroSrv<'env> { } } +#[derive(Debug)] +pub enum ProcMacroClientError { + Cancelled { reason: String }, + Io(std::io::Error), + Protocol(String), + Eof, +} + +#[derive(Debug)] +pub enum ProcMacroPanicMarker { + Cancelled { reason: String }, + Internal { reason: String }, +} + pub type ProcMacroClientHandle<'a> = &'a mut (dyn ProcMacroClientInterface + Sync + Send); pub trait ProcMacroClientInterface { @@ -110,6 +124,22 @@ pub trait ProcMacroClientInterface { const EXPANDER_STACK_SIZE: usize = 8 * 1024 * 1024; +pub enum ExpandError { + Panic(PanicMessage), + Cancelled { reason: Option }, + Internal { reason: Option }, +} + +impl ExpandError { + pub fn into_string(self) -> Option { + match self { + ExpandError::Panic(panic_message) => panic_message.into_string(), + ExpandError::Cancelled { reason } => reason, + ExpandError::Internal { reason } => reason, + } + } +} + impl ProcMacroSrv<'_> { pub fn expand( &self, @@ -123,10 +153,10 @@ impl ProcMacroSrv<'_> { call_site: S, mixed_site: S, callback: Option>, - ) -> Result, PanicMessage> { + ) -> Result, ExpandError> { let snapped_env = self.env; - let expander = self.expander(lib.as_ref()).map_err(|err| PanicMessage { - message: Some(format!("failed to load macro: {err}")), + let expander = self.expander(lib.as_ref()).map_err(|err| ExpandError::Internal { + reason: Some(format!("failed to load macro: {err}")), })?; let prev_env = EnvChange::apply(snapped_env, env, current_dir.as_ref().map(<_>::as_ref)); @@ -144,8 +174,22 @@ impl ProcMacroSrv<'_> { ) }); match thread.unwrap().join() { - Ok(res) => res, - Err(e) => std::panic::resume_unwind(e), + Ok(res) => res.map_err(ExpandError::Panic), + + Err(payload) => { + if let Some(marker) = payload.downcast_ref::() { + return match marker { + ProcMacroPanicMarker::Cancelled { reason } => { + Err(ExpandError::Cancelled { reason: Some(reason.clone()) }) + } + ProcMacroPanicMarker::Internal { reason } => { + Err(ExpandError::Internal { reason: Some(reason.clone()) }) + } + }; + } + + std::panic::resume_unwind(payload) + } } }); prev_env.rollback();