proc_macro_api/
process.rs

//! Handles the process lifetime and message passing for the proc-macro client.
2
3use std::{
4    fmt::Debug,
5    io::{self, BufRead, BufReader, Read, Write},
6    panic::AssertUnwindSafe,
7    process::{Child, ChildStdin, ChildStdout, Command, Stdio},
8    sync::{
9        Arc, Mutex, OnceLock,
10        atomic::{AtomicU32, Ordering},
11    },
12};
13
14use paths::AbsPath;
15use semver::Version;
16use span::Span;
17use stdx::JodChild;
18
19use crate::{
20    ProcMacro, ProcMacroKind, ProtocolFormat, ServerError,
21    bidirectional_protocol::{
22        self, SubCallback,
23        msg::{BidirectionalMessage, SubResponse},
24        reject_subrequests,
25    },
26    legacy_protocol::{self, SpanMode},
27    version,
28};
29
/// Represents a process handling proc-macro communication.
pub(crate) struct ProcMacroServerProcess {
    /// The state of the proc-macro server process, the protocol is currently strictly sequential
    /// hence the lock on the state.
    state: Mutex<ProcessSrvState>,
    /// API version reported by the server's version check; `run` initializes it to 0 and
    /// overwrites it once the check succeeds.
    version: u32,
    /// Wire protocol chosen for this server, together with the span mode currently in use.
    protocol: Protocol,
    /// Populated when the server exits.
    exited: OnceLock<AssertUnwindSafe<ServerError>>,
    /// Number of `expand` calls currently in flight on this server.
    active: AtomicU32,
}
41
42impl std::fmt::Debug for ProcMacroServerProcess {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        f.debug_struct("ProcMacroServerProcess")
45            .field("version", &self.version)
46            .field("protocol", &self.protocol)
47            .field("exited", &self.exited)
48            .finish()
49    }
50}
51
/// Wire protocol used to talk to the proc-macro server.
///
/// Each variant carries the span mode in use. `run` starts every server in `SpanMode::Id` and
/// switches to the mode returned by `enable_rust_analyzer_spans` when that call succeeds.
#[derive(Debug, Clone)]
pub(crate) enum Protocol {
    /// Legacy JSON protocol; selected for `--format` json-legacy or when no format flag is passed.
    LegacyJson { mode: SpanMode },
    /// Postcard-based bidirectional protocol; requests take a `SubCallback` that answers
    /// sub-requests sent back by the server.
    BidirectionalPostcardPrototype { mode: SpanMode },
}
57
/// A proc-macro server process that can report why it exited.
///
/// `with_locked_io` consults this after a `BrokenPipe` I/O error so it can surface the exit
/// reason instead of the raw pipe error. The `Process` implementation returns `None` while the
/// child is still running or its status cannot be queried.
pub trait ProcessExit: Send + Sync {
    fn exit_err(&mut self) -> Option<ServerError>;
}
61
62impl ProcessExit for Process {
63    fn exit_err(&mut self) -> Option<ServerError> {
64        match self.child.try_wait() {
65            Ok(None) | Err(_) => None,
66            Ok(Some(status)) => {
67                let mut msg = String::new();
68                if !status.success()
69                    && let Some(stderr) = self.child.stderr.as_mut()
70                {
71                    _ = stderr.read_to_string(&mut msg);
72                }
73                Some(ServerError {
74                    message: format!(
75                        "proc-macro server exited with {status}{}{msg}",
76                        if msg.is_empty() { "" } else { ": " }
77                    ),
78                    io: None,
79                })
80            }
81        }
82    }
83}
84
/// Maintains the state of the proc-macro server process.
///
/// Kept behind `ProcMacroServerProcess::state`'s mutex so each request/response exchange gets
/// exclusive use of the pipes.
pub(crate) struct ProcessSrvState {
    /// Handle used to query the exit status after a broken pipe.
    process: Box<dyn ProcessExit>,
    /// Writer connected to the server's stdin (requests go here).
    stdin: Box<dyn Write + Send + Sync>,
    /// Buffered reader over the server's stdout (responses are read from here).
    stdout: Box<dyn BufRead + Send + Sync>,
}
91
92impl ProcMacroServerProcess {
93    /// Starts the proc-macro server and performs a version check
94    pub(crate) fn spawn<'a>(
95        process_path: &AbsPath,
96        env: impl IntoIterator<
97            Item = (impl AsRef<std::ffi::OsStr>, &'a Option<impl 'a + AsRef<std::ffi::OsStr>>),
98        > + Clone,
99        version: Option<&Version>,
100    ) -> io::Result<ProcMacroServerProcess> {
101        Self::run(
102            |format| {
103                let mut process = Process::run(
104                    process_path,
105                    env.clone(),
106                    format.map(|format| format.to_string()).as_deref(),
107                )?;
108                let (stdin, stdout) = process.stdio().expect("couldn't access child stdio");
109
110                Ok((Box::new(process), Box::new(stdin), Box::new(stdout)))
111            },
112            version,
113            || {
114                #[expect(clippy::disallowed_methods)]
115                Command::new(process_path)
116                    .arg("--version")
117                    .output()
118                    .map(|output| String::from_utf8_lossy(&output.stdout).trim().to_owned())
119                    .unwrap_or_else(|_| "unknown version".to_owned())
120            },
121        )
122    }
123
124    /// Invokes `spawn` and performs a version check.
125    pub(crate) fn run(
126        spawn: impl Fn(
127            Option<ProtocolFormat>,
128        ) -> io::Result<(
129            Box<dyn ProcessExit>,
130            Box<dyn Write + Send + Sync>,
131            Box<dyn BufRead + Send + Sync>,
132        )>,
133        version: Option<&Version>,
134        binary_server_version: impl Fn() -> String,
135    ) -> io::Result<ProcMacroServerProcess> {
136        const VERSION: Version = Version::new(1, 93, 0);
137        // we do `>` for nightly as this started working in the middle of the 1.93 nightly release, so we dont want to break on half of the nightlies
138        let has_working_format_flag = version.map_or(false, |v| {
139            if v.pre.as_str() == "nightly" { *v > VERSION } else { *v >= VERSION }
140        });
141
142        let formats: &[_] = if std::env::var_os("RUST_ANALYZER_USE_POSTCARD").is_some()
143            && has_working_format_flag
144        {
145            &[
146                Some(ProtocolFormat::BidirectionalPostcardPrototype),
147                Some(ProtocolFormat::JsonLegacy),
148            ]
149        } else {
150            &[None]
151        };
152
153        let mut err = None;
154        for &format in formats {
155            let create_srv = || {
156                let (process, stdin, stdout) = spawn(format)?;
157
158                io::Result::Ok(ProcMacroServerProcess {
159                    state: Mutex::new(ProcessSrvState { process, stdin, stdout }),
160                    version: 0,
161                    protocol: match format {
162                        Some(ProtocolFormat::BidirectionalPostcardPrototype) => {
163                            Protocol::BidirectionalPostcardPrototype { mode: SpanMode::Id }
164                        }
165                        Some(ProtocolFormat::JsonLegacy) | None => {
166                            Protocol::LegacyJson { mode: SpanMode::Id }
167                        }
168                    },
169                    exited: OnceLock::new(),
170                    active: AtomicU32::new(0),
171                })
172            };
173            let mut srv = create_srv()?;
174            tracing::info!("sending proc-macro server version check");
175            match srv.version_check(Some(&reject_subrequests)) {
176                Ok(v) if v > version::CURRENT_API_VERSION => {
177                    let process_version = binary_server_version();
178                    err = Some(io::Error::other(format!(
179                        "Your installed proc-macro server is too new for your rust-analyzer. API version: {}, server version: {process_version}. \
180                        This will prevent proc-macro expansion from working. Please consider updating your rust-analyzer to ensure compatibility with your current toolchain.",
181                        version::CURRENT_API_VERSION
182                    )));
183                }
184                Ok(v) => {
185                    tracing::info!("Proc-macro server version: {v}");
186                    srv.version = v;
187                    if srv.version >= version::RUST_ANALYZER_SPAN_SUPPORT
188                        && let Ok(new_mode) =
189                            srv.enable_rust_analyzer_spans(Some(&reject_subrequests))
190                    {
191                        match &mut srv.protocol {
192                            Protocol::LegacyJson { mode }
193                            | Protocol::BidirectionalPostcardPrototype { mode } => *mode = new_mode,
194                        }
195                    }
196                    tracing::info!("Proc-macro server protocol: {:?}", srv.protocol);
197                    return Ok(srv);
198                }
199                Err(e) => {
200                    tracing::info!(%e, "proc-macro version check failed");
201                    err = Some(io::Error::other(format!(
202                        "proc-macro server version check failed: {e}"
203                    )))
204                }
205            }
206        }
207        Err(err.unwrap())
208    }
209
210    /// Finds proc-macros in a given dynamic library.
211    pub(crate) fn find_proc_macros(
212        &self,
213        dylib_path: &AbsPath,
214    ) -> Result<Result<Vec<(String, ProcMacroKind)>, String>, ServerError> {
215        match self.protocol {
216            Protocol::LegacyJson { .. } => legacy_protocol::find_proc_macros(self, dylib_path),
217
218            Protocol::BidirectionalPostcardPrototype { .. } => {
219                bidirectional_protocol::find_proc_macros(self, dylib_path, &|_| {
220                    Ok(SubResponse::Cancel {
221                        reason: String::from(
222                            "Server should not do a sub request when loading proc-macros",
223                        ),
224                    })
225                })
226            }
227        }
228    }
229
230    /// Returns the server error if the process has exited.
231    pub(crate) fn exited(&self) -> Option<&ServerError> {
232        self.exited.get().map(|it| &it.0)
233    }
234
235    /// Retrieves the API version of the proc-macro server.
236    pub(crate) fn version(&self) -> u32 {
237        self.version
238    }
239
240    /// Enable support for rust-analyzer span mode if the server supports it.
241    pub(crate) fn rust_analyzer_spans(&self) -> bool {
242        match self.protocol {
243            Protocol::LegacyJson { mode } => mode == SpanMode::RustAnalyzer,
244            Protocol::BidirectionalPostcardPrototype { mode } => mode == SpanMode::RustAnalyzer,
245        }
246    }
247
248    /// Checks the API version of the running proc-macro server.
249    fn version_check(&self, callback: Option<SubCallback<'_>>) -> Result<u32, ServerError> {
250        match self.protocol {
251            Protocol::LegacyJson { .. } => legacy_protocol::version_check(self),
252            Protocol::BidirectionalPostcardPrototype { .. } => {
253                let cb = callback.expect("callback required for bidirectional protocol");
254                bidirectional_protocol::version_check(self, cb)
255            }
256        }
257    }
258
259    /// Enable support for rust-analyzer span mode if the server supports it.
260    fn enable_rust_analyzer_spans(
261        &self,
262        callback: Option<SubCallback<'_>>,
263    ) -> Result<SpanMode, ServerError> {
264        match self.protocol {
265            Protocol::LegacyJson { .. } => legacy_protocol::enable_rust_analyzer_spans(self),
266            Protocol::BidirectionalPostcardPrototype { .. } => {
267                let cb = callback.expect("callback required for bidirectional protocol");
268                bidirectional_protocol::enable_rust_analyzer_spans(self, cb)
269            }
270        }
271    }
272
273    pub(crate) fn expand(
274        &self,
275        proc_macro: &ProcMacro,
276        subtree: tt::SubtreeView<'_>,
277        attr: Option<tt::SubtreeView<'_>>,
278        env: Vec<(String, String)>,
279        def_site: Span,
280        call_site: Span,
281        mixed_site: Span,
282        current_dir: String,
283        callback: Option<SubCallback<'_>>,
284    ) -> Result<Result<tt::TopSubtree, String>, ServerError> {
285        self.active.fetch_add(1, Ordering::AcqRel);
286        let result = match self.protocol {
287            Protocol::LegacyJson { .. } => legacy_protocol::expand(
288                proc_macro,
289                self,
290                subtree,
291                attr,
292                env,
293                def_site,
294                call_site,
295                mixed_site,
296                current_dir,
297            ),
298            Protocol::BidirectionalPostcardPrototype { .. } => bidirectional_protocol::expand(
299                proc_macro,
300                self,
301                subtree,
302                attr,
303                env,
304                def_site,
305                call_site,
306                mixed_site,
307                current_dir,
308                callback.expect("callback required for bidirectional protocol"),
309            ),
310        };
311
312        self.active.fetch_sub(1, Ordering::AcqRel);
313        result
314    }
315
316    pub(crate) fn send_task_legacy<Request, Response>(
317        &self,
318        send: impl FnOnce(
319            &mut dyn Write,
320            &mut dyn BufRead,
321            Request,
322            &mut String,
323        ) -> Result<Option<Response>, ServerError>,
324        req: Request,
325    ) -> Result<Response, ServerError> {
326        self.with_locked_io(String::new(), |writer, reader, buf| {
327            send(writer, reader, req, buf).and_then(|res| {
328                res.ok_or_else(|| {
329                    let message = "proc-macro server did not respond with data".to_owned();
330                    ServerError {
331                        io: Some(Arc::new(io::Error::new(
332                            io::ErrorKind::BrokenPipe,
333                            message.clone(),
334                        ))),
335                        message,
336                    }
337                })
338            })
339        })
340    }
341
342    fn with_locked_io<R, B>(
343        &self,
344        mut buf: B,
345        f: impl FnOnce(&mut dyn Write, &mut dyn BufRead, &mut B) -> Result<R, ServerError>,
346    ) -> Result<R, ServerError> {
347        let state = &mut *self.state.lock().unwrap();
348        f(&mut state.stdin, &mut state.stdout, &mut buf).map_err(|e| {
349            if e.io.as_ref().map(|it| it.kind()) == Some(io::ErrorKind::BrokenPipe) {
350                match state.process.exit_err() {
351                    None => e,
352                    Some(server_error) => {
353                        self.exited.get_or_init(|| AssertUnwindSafe(server_error)).0.clone()
354                    }
355                }
356            } else {
357                e
358            }
359        })
360    }
361
362    pub(crate) fn run_bidirectional(
363        &self,
364        initial: BidirectionalMessage,
365        callback: SubCallback<'_>,
366    ) -> Result<BidirectionalMessage, ServerError> {
367        self.with_locked_io(Vec::new(), |writer, reader, buf| {
368            bidirectional_protocol::run_conversation(writer, reader, buf, initial, callback)
369        })
370    }
371
372    pub(crate) fn number_of_active_req(&self) -> u32 {
373        self.active.load(Ordering::Acquire)
374    }
375}
376
/// Manages the execution of the proc-macro server process.
///
/// The child is wrapped in `JodChild`; presumably this kills the server when the handle is
/// dropped (NOTE(review): confirm against `stdx::JodChild`).
#[derive(Debug)]
struct Process {
    child: JodChild,
}
382
383impl Process {
384    /// Runs a new proc-macro server process with the specified environment variables.
385    fn run<'a>(
386        path: &AbsPath,
387        env: impl IntoIterator<
388            Item = (impl AsRef<std::ffi::OsStr>, &'a Option<impl 'a + AsRef<std::ffi::OsStr>>),
389        >,
390        format: Option<&str>,
391    ) -> io::Result<Process> {
392        let child = JodChild(mk_child(path, env, format)?);
393        Ok(Process { child })
394    }
395
396    /// Retrieves stdin and stdout handles for the process.
397    fn stdio(&mut self) -> Option<(ChildStdin, BufReader<ChildStdout>)> {
398        let stdin = self.child.stdin.take()?;
399        let stdout = self.child.stdout.take()?;
400        let read = BufReader::new(stdout);
401
402        Some((stdin, read))
403    }
404}
405
406/// Creates and configures a new child process for the proc-macro server.
407fn mk_child<'a>(
408    path: &AbsPath,
409    extra_env: impl IntoIterator<
410        Item = (impl AsRef<std::ffi::OsStr>, &'a Option<impl 'a + AsRef<std::ffi::OsStr>>),
411    >,
412    format: Option<&str>,
413) -> io::Result<Child> {
414    #[allow(clippy::disallowed_methods)]
415    let mut cmd = Command::new(path);
416    for env in extra_env {
417        match env {
418            (key, Some(val)) => cmd.env(key, val),
419            (key, None) => cmd.env_remove(key),
420        };
421    }
422    if let Some(format) = format {
423        cmd.arg("--format");
424        cmd.arg(format);
425    }
426    cmd.env("RUST_ANALYZER_INTERNALS_DO_NOT_USE", "this is unstable")
427        .stdin(Stdio::piped())
428        .stdout(Stdio::piped())
429        .stderr(Stdio::inherit());
430    if cfg!(windows) {
431        let mut path_var = std::ffi::OsString::new();
432        path_var.push(path.parent().unwrap().parent().unwrap());
433        path_var.push("\\bin;");
434        path_var.push(std::env::var_os("PATH").unwrap_or_default());
435        cmd.env("PATH", path_var);
436    }
437    cmd.spawn()
438}