proc_macro_api/
process.rs

1//! Handle process life-time and message passing for proc-macro client
2
3use std::{
4    fmt::Debug,
5    io::{self, BufRead, BufReader, Read, Write},
6    panic::AssertUnwindSafe,
7    process::{Child, ChildStdin, ChildStdout, Command, Stdio},
8    sync::{
9        Arc, Mutex, OnceLock,
10        atomic::{AtomicU32, Ordering},
11    },
12};
13
14use paths::AbsPath;
15use semver::Version;
16use span::Span;
17use stdx::JodChild;
18
19use crate::{
20    ProcMacro, ProcMacroKind, ProtocolFormat, ServerError,
21    bidirectional_protocol::{self, SubCallback, msg::BidirectionalMessage, reject_subrequests},
22    legacy_protocol::{self, SpanMode},
23    version,
24};
25
26/// Represents a process handling proc-macro communication.
27pub(crate) struct ProcMacroServerProcess {
28    /// The state of the proc-macro server process, the protocol is currently strictly sequential
29    /// hence the lock on the state.
30    state: Mutex<ProcessSrvState>,
31    version: u32,
32    protocol: Protocol,
33    /// Populated when the server exits.
34    exited: OnceLock<AssertUnwindSafe<ServerError>>,
35    active: AtomicU32,
36}
37
38impl std::fmt::Debug for ProcMacroServerProcess {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        f.debug_struct("ProcMacroServerProcess")
41            .field("version", &self.version)
42            .field("protocol", &self.protocol)
43            .field("exited", &self.exited)
44            .finish()
45    }
46}
47
48#[derive(Debug, Clone)]
49pub(crate) enum Protocol {
50    LegacyJson { mode: SpanMode },
51    BidirectionalPostcardPrototype { mode: SpanMode },
52}
53
54pub trait ProcessExit: Send + Sync {
55    fn exit_err(&mut self) -> Option<ServerError>;
56}
57
58impl ProcessExit for Process {
59    fn exit_err(&mut self) -> Option<ServerError> {
60        match self.child.try_wait() {
61            Ok(None) | Err(_) => None,
62            Ok(Some(status)) => {
63                let mut msg = String::new();
64                if !status.success()
65                    && let Some(stderr) = self.child.stderr.as_mut()
66                {
67                    _ = stderr.read_to_string(&mut msg);
68                }
69                Some(ServerError {
70                    message: format!(
71                        "proc-macro server exited with {status}{}{msg}",
72                        if msg.is_empty() { "" } else { ": " }
73                    ),
74                    io: None,
75                })
76            }
77        }
78    }
79}
80
81/// Maintains the state of the proc-macro server process.
82pub(crate) struct ProcessSrvState {
83    process: Box<dyn ProcessExit>,
84    stdin: Box<dyn Write + Send + Sync>,
85    stdout: Box<dyn BufRead + Send + Sync>,
86}
87
88impl ProcMacroServerProcess {
89    /// Starts the proc-macro server and performs a version check
90    pub(crate) fn spawn<'a>(
91        process_path: &AbsPath,
92        env: impl IntoIterator<
93            Item = (impl AsRef<std::ffi::OsStr>, &'a Option<impl 'a + AsRef<std::ffi::OsStr>>),
94        > + Clone,
95        version: Option<&Version>,
96    ) -> io::Result<ProcMacroServerProcess> {
97        Self::run(
98            |format| {
99                let mut process = Process::run(
100                    process_path,
101                    env.clone(),
102                    format.map(|format| format.to_string()).as_deref(),
103                )?;
104                let (stdin, stdout) = process.stdio().expect("couldn't access child stdio");
105
106                Ok((Box::new(process), Box::new(stdin), Box::new(stdout)))
107            },
108            version,
109            || {
110                #[expect(clippy::disallowed_methods)]
111                Command::new(process_path)
112                    .arg("--version")
113                    .output()
114                    .map(|output| String::from_utf8_lossy(&output.stdout).trim().to_owned())
115                    .unwrap_or_else(|_| "unknown version".to_owned())
116            },
117        )
118    }
119
120    /// Invokes `spawn` and performs a version check.
121    pub(crate) fn run(
122        spawn: impl Fn(
123            Option<ProtocolFormat>,
124        ) -> io::Result<(
125            Box<dyn ProcessExit>,
126            Box<dyn Write + Send + Sync>,
127            Box<dyn BufRead + Send + Sync>,
128        )>,
129        version: Option<&Version>,
130        binary_server_version: impl Fn() -> String,
131    ) -> io::Result<ProcMacroServerProcess> {
132        const VERSION: Version = Version::new(1, 93, 0);
133        // we do `>` for nightly as this started working in the middle of the 1.93 nightly release, so we dont want to break on half of the nightlies
134        let has_working_format_flag = version.map_or(false, |v| {
135            if v.pre.as_str() == "nightly" { *v > VERSION } else { *v >= VERSION }
136        });
137
138        let formats: &[_] = if std::env::var_os("RUST_ANALYZER_USE_POSTCARD").is_some()
139            && has_working_format_flag
140        {
141            &[
142                Some(ProtocolFormat::BidirectionalPostcardPrototype),
143                Some(ProtocolFormat::JsonLegacy),
144            ]
145        } else {
146            &[None]
147        };
148
149        let mut err = None;
150        for &format in formats {
151            let create_srv = || {
152                let (process, stdin, stdout) = spawn(format)?;
153
154                io::Result::Ok(ProcMacroServerProcess {
155                    state: Mutex::new(ProcessSrvState { process, stdin, stdout }),
156                    version: 0,
157                    protocol: match format {
158                        Some(ProtocolFormat::BidirectionalPostcardPrototype) => {
159                            Protocol::BidirectionalPostcardPrototype { mode: SpanMode::Id }
160                        }
161                        Some(ProtocolFormat::JsonLegacy) | None => {
162                            Protocol::LegacyJson { mode: SpanMode::Id }
163                        }
164                    },
165                    exited: OnceLock::new(),
166                    active: AtomicU32::new(0),
167                })
168            };
169            let mut srv = create_srv()?;
170            tracing::info!("sending proc-macro server version check");
171            match srv.version_check(Some(&reject_subrequests)) {
172                Ok(v) if v > version::CURRENT_API_VERSION => {
173                    let process_version = binary_server_version();
174                    err = Some(io::Error::other(format!(
175                        "Your installed proc-macro server is too new for your rust-analyzer. API version: {}, server version: {process_version}. \
176                        This will prevent proc-macro expansion from working. Please consider updating your rust-analyzer to ensure compatibility with your current toolchain.",
177                        version::CURRENT_API_VERSION
178                    )));
179                }
180                Ok(v) => {
181                    tracing::info!("Proc-macro server version: {v}");
182                    srv.version = v;
183                    if srv.version >= version::RUST_ANALYZER_SPAN_SUPPORT
184                        && let Ok(new_mode) =
185                            srv.enable_rust_analyzer_spans(Some(&reject_subrequests))
186                    {
187                        match &mut srv.protocol {
188                            Protocol::LegacyJson { mode }
189                            | Protocol::BidirectionalPostcardPrototype { mode } => *mode = new_mode,
190                        }
191                    }
192                    tracing::info!("Proc-macro server protocol: {:?}", srv.protocol);
193                    return Ok(srv);
194                }
195                Err(e) => {
196                    tracing::info!(%e, "proc-macro version check failed");
197                    err = Some(io::Error::other(format!(
198                        "proc-macro server version check failed: {e}"
199                    )))
200                }
201            }
202        }
203        Err(err.unwrap())
204    }
205
206    /// Finds proc-macros in a given dynamic library.
207    pub(crate) fn find_proc_macros(
208        &self,
209        dylib_path: &AbsPath,
210        callback: Option<SubCallback<'_>>,
211    ) -> Result<Result<Vec<(String, ProcMacroKind)>, String>, ServerError> {
212        match self.protocol {
213            Protocol::LegacyJson { .. } => legacy_protocol::find_proc_macros(self, dylib_path),
214
215            Protocol::BidirectionalPostcardPrototype { .. } => {
216                let cb = callback.expect("callback required for bidirectional protocol");
217                bidirectional_protocol::find_proc_macros(self, dylib_path, cb)
218            }
219        }
220    }
221
222    /// Returns the server error if the process has exited.
223    pub(crate) fn exited(&self) -> Option<&ServerError> {
224        self.exited.get().map(|it| &it.0)
225    }
226
227    /// Retrieves the API version of the proc-macro server.
228    pub(crate) fn version(&self) -> u32 {
229        self.version
230    }
231
232    /// Enable support for rust-analyzer span mode if the server supports it.
233    pub(crate) fn rust_analyzer_spans(&self) -> bool {
234        match self.protocol {
235            Protocol::LegacyJson { mode } => mode == SpanMode::RustAnalyzer,
236            Protocol::BidirectionalPostcardPrototype { mode } => mode == SpanMode::RustAnalyzer,
237        }
238    }
239
240    /// Checks the API version of the running proc-macro server.
241    fn version_check(&self, callback: Option<SubCallback<'_>>) -> Result<u32, ServerError> {
242        match self.protocol {
243            Protocol::LegacyJson { .. } => legacy_protocol::version_check(self),
244            Protocol::BidirectionalPostcardPrototype { .. } => {
245                let cb = callback.expect("callback required for bidirectional protocol");
246                bidirectional_protocol::version_check(self, cb)
247            }
248        }
249    }
250
251    /// Enable support for rust-analyzer span mode if the server supports it.
252    fn enable_rust_analyzer_spans(
253        &self,
254        callback: Option<SubCallback<'_>>,
255    ) -> Result<SpanMode, ServerError> {
256        match self.protocol {
257            Protocol::LegacyJson { .. } => legacy_protocol::enable_rust_analyzer_spans(self),
258            Protocol::BidirectionalPostcardPrototype { .. } => {
259                let cb = callback.expect("callback required for bidirectional protocol");
260                bidirectional_protocol::enable_rust_analyzer_spans(self, cb)
261            }
262        }
263    }
264
265    pub(crate) fn expand(
266        &self,
267        proc_macro: &ProcMacro,
268        subtree: tt::SubtreeView<'_>,
269        attr: Option<tt::SubtreeView<'_>>,
270        env: Vec<(String, String)>,
271        def_site: Span,
272        call_site: Span,
273        mixed_site: Span,
274        current_dir: String,
275        callback: Option<SubCallback<'_>>,
276    ) -> Result<Result<tt::TopSubtree, String>, ServerError> {
277        self.active.fetch_add(1, Ordering::AcqRel);
278        let result = match self.protocol {
279            Protocol::LegacyJson { .. } => legacy_protocol::expand(
280                proc_macro,
281                self,
282                subtree,
283                attr,
284                env,
285                def_site,
286                call_site,
287                mixed_site,
288                current_dir,
289            ),
290            Protocol::BidirectionalPostcardPrototype { .. } => bidirectional_protocol::expand(
291                proc_macro,
292                self,
293                subtree,
294                attr,
295                env,
296                def_site,
297                call_site,
298                mixed_site,
299                current_dir,
300                callback.expect("callback required for bidirectional protocol"),
301            ),
302        };
303
304        self.active.fetch_sub(1, Ordering::AcqRel);
305        result
306    }
307
308    pub(crate) fn send_task_legacy<Request, Response>(
309        &self,
310        send: impl FnOnce(
311            &mut dyn Write,
312            &mut dyn BufRead,
313            Request,
314            &mut String,
315        ) -> Result<Option<Response>, ServerError>,
316        req: Request,
317    ) -> Result<Response, ServerError> {
318        self.with_locked_io(String::new(), |writer, reader, buf| {
319            send(writer, reader, req, buf).and_then(|res| {
320                res.ok_or_else(|| {
321                    let message = "proc-macro server did not respond with data".to_owned();
322                    ServerError {
323                        io: Some(Arc::new(io::Error::new(
324                            io::ErrorKind::BrokenPipe,
325                            message.clone(),
326                        ))),
327                        message,
328                    }
329                })
330            })
331        })
332    }
333
334    fn with_locked_io<R, B>(
335        &self,
336        mut buf: B,
337        f: impl FnOnce(&mut dyn Write, &mut dyn BufRead, &mut B) -> Result<R, ServerError>,
338    ) -> Result<R, ServerError> {
339        let state = &mut *self.state.lock().unwrap();
340        f(&mut state.stdin, &mut state.stdout, &mut buf).map_err(|e| {
341            if e.io.as_ref().map(|it| it.kind()) == Some(io::ErrorKind::BrokenPipe) {
342                match state.process.exit_err() {
343                    None => e,
344                    Some(server_error) => {
345                        self.exited.get_or_init(|| AssertUnwindSafe(server_error)).0.clone()
346                    }
347                }
348            } else {
349                e
350            }
351        })
352    }
353
354    pub(crate) fn run_bidirectional(
355        &self,
356        initial: BidirectionalMessage,
357        callback: SubCallback<'_>,
358    ) -> Result<BidirectionalMessage, ServerError> {
359        self.with_locked_io(Vec::new(), |writer, reader, buf| {
360            bidirectional_protocol::run_conversation(writer, reader, buf, initial, callback)
361        })
362    }
363
364    pub(crate) fn number_of_active_req(&self) -> u32 {
365        self.active.load(Ordering::Acquire)
366    }
367}
368
369/// Manages the execution of the proc-macro server process.
370#[derive(Debug)]
371struct Process {
372    child: JodChild,
373}
374
375impl Process {
376    /// Runs a new proc-macro server process with the specified environment variables.
377    fn run<'a>(
378        path: &AbsPath,
379        env: impl IntoIterator<
380            Item = (impl AsRef<std::ffi::OsStr>, &'a Option<impl 'a + AsRef<std::ffi::OsStr>>),
381        >,
382        format: Option<&str>,
383    ) -> io::Result<Process> {
384        let child = JodChild(mk_child(path, env, format)?);
385        Ok(Process { child })
386    }
387
388    /// Retrieves stdin and stdout handles for the process.
389    fn stdio(&mut self) -> Option<(ChildStdin, BufReader<ChildStdout>)> {
390        let stdin = self.child.stdin.take()?;
391        let stdout = self.child.stdout.take()?;
392        let read = BufReader::new(stdout);
393
394        Some((stdin, read))
395    }
396}
397
398/// Creates and configures a new child process for the proc-macro server.
399fn mk_child<'a>(
400    path: &AbsPath,
401    extra_env: impl IntoIterator<
402        Item = (impl AsRef<std::ffi::OsStr>, &'a Option<impl 'a + AsRef<std::ffi::OsStr>>),
403    >,
404    format: Option<&str>,
405) -> io::Result<Child> {
406    #[allow(clippy::disallowed_methods)]
407    let mut cmd = Command::new(path);
408    for env in extra_env {
409        match env {
410            (key, Some(val)) => cmd.env(key, val),
411            (key, None) => cmd.env_remove(key),
412        };
413    }
414    if let Some(format) = format {
415        cmd.arg("--format");
416        cmd.arg(format);
417    }
418    cmd.env("RUST_ANALYZER_INTERNALS_DO_NOT_USE", "this is unstable")
419        .stdin(Stdio::piped())
420        .stdout(Stdio::piped())
421        .stderr(Stdio::inherit());
422    if cfg!(windows) {
423        let mut path_var = std::ffi::OsString::new();
424        path_var.push(path.parent().unwrap().parent().unwrap());
425        path_var.push("\\bin;");
426        path_var.push(std::env::var_os("PATH").unwrap_or_default());
427        cmd.env("PATH", path_var);
428    }
429    cmd.spawn()
430}