1use std::{
4 fmt::Debug,
5 io::{self, BufRead, BufReader, Read, Write},
6 panic::AssertUnwindSafe,
7 process::{Child, ChildStdin, ChildStdout, Command, Stdio},
8 sync::{
9 Arc, Mutex, OnceLock,
10 atomic::{AtomicU32, Ordering},
11 },
12};
13
14use paths::AbsPath;
15use semver::Version;
16use span::Span;
17use stdx::JodChild;
18
19use crate::{
20 ProcMacro, ProcMacroKind, ProtocolFormat, ServerError,
21 bidirectional_protocol::{self, SubCallback, msg::BidirectionalMessage, reject_subrequests},
22 legacy_protocol::{self, SpanMode},
23 version,
24};
25
/// Handle to a running proc-macro server process and the protocol negotiated with it.
///
/// All traffic goes through the stdin/stdout pair inside `state`; the mutex is held for the
/// whole of each exchange, so only one request/response conversation uses the pipes at a time.
pub(crate) struct ProcMacroServerProcess {
    /// Process handle plus its stdin/stdout, locked for the duration of each exchange.
    state: Mutex<ProcessSrvState>,
    /// API version reported by the server's version check (0 until the check succeeds).
    version: u32,
    /// Wire protocol and span mode negotiated at startup.
    protocol: Protocol,
    /// Exit error recorded the first time a broken pipe coincides with an observed process exit;
    /// later broken-pipe failures that also observe an exit return this first recorded error.
    exited: OnceLock<AssertUnwindSafe<ServerError>>,
    /// Number of `expand` calls currently in progress on this server.
    active: AtomicU32,
}
37
38impl std::fmt::Debug for ProcMacroServerProcess {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 f.debug_struct("ProcMacroServerProcess")
41 .field("version", &self.version)
42 .field("protocol", &self.protocol)
43 .field("exited", &self.exited)
44 .finish()
45 }
46}
47
/// Wire protocol spoken with the proc-macro server; each variant carries the negotiated span mode.
#[derive(Debug, Clone)]
pub(crate) enum Protocol {
    /// Legacy JSON protocol; chosen when no `--format` flag is passed or `JsonLegacy` is requested.
    LegacyJson { mode: SpanMode },
    /// Bidirectional postcard protocol (prototype); calls over it require a `SubCallback`.
    BidirectionalPostcardPrototype { mode: SpanMode },
}
53
/// Lets the server handle ask its underlying process whether (and how) it has exited.
pub trait ProcessExit: Send + Sync {
    /// Returns `Some(error)` describing the exit once the process has terminated, and `None`
    /// when no exit has been observed.
    fn exit_err(&mut self) -> Option<ServerError>;
}
57
58impl ProcessExit for Process {
59 fn exit_err(&mut self) -> Option<ServerError> {
60 match self.child.try_wait() {
61 Ok(None) | Err(_) => None,
62 Ok(Some(status)) => {
63 let mut msg = String::new();
64 if !status.success()
65 && let Some(stderr) = self.child.stderr.as_mut()
66 {
67 _ = stderr.read_to_string(&mut msg);
68 }
69 Some(ServerError {
70 message: format!(
71 "proc-macro server exited with {status}{}{msg}",
72 if msg.is_empty() { "" } else { ": " }
73 ),
74 io: None,
75 })
76 }
77 }
78 }
79}
80
/// Mutable I/O state of a server: its process handle plus the pipes used to talk to it.
pub(crate) struct ProcessSrvState {
    /// Queried for an exit error when an exchange fails with a broken pipe.
    process: Box<dyn ProcessExit>,
    /// Requests are written to the server through this stream.
    stdin: Box<dyn Write + Send + Sync>,
    /// Responses are read from the server through this stream.
    stdout: Box<dyn BufRead + Send + Sync>,
}
87
88impl ProcMacroServerProcess {
89 pub(crate) fn spawn<'a>(
91 process_path: &AbsPath,
92 env: impl IntoIterator<
93 Item = (impl AsRef<std::ffi::OsStr>, &'a Option<impl 'a + AsRef<std::ffi::OsStr>>),
94 > + Clone,
95 version: Option<&Version>,
96 ) -> io::Result<ProcMacroServerProcess> {
97 Self::run(
98 |format| {
99 let mut process = Process::run(
100 process_path,
101 env.clone(),
102 format.map(|format| format.to_string()).as_deref(),
103 )?;
104 let (stdin, stdout) = process.stdio().expect("couldn't access child stdio");
105
106 Ok((Box::new(process), Box::new(stdin), Box::new(stdout)))
107 },
108 version,
109 || {
110 #[expect(clippy::disallowed_methods)]
111 Command::new(process_path)
112 .arg("--version")
113 .output()
114 .map(|output| String::from_utf8_lossy(&output.stdout).trim().to_owned())
115 .unwrap_or_else(|_| "unknown version".to_owned())
116 },
117 )
118 }
119
120 pub(crate) fn run(
122 spawn: impl Fn(
123 Option<ProtocolFormat>,
124 ) -> io::Result<(
125 Box<dyn ProcessExit>,
126 Box<dyn Write + Send + Sync>,
127 Box<dyn BufRead + Send + Sync>,
128 )>,
129 version: Option<&Version>,
130 binary_server_version: impl Fn() -> String,
131 ) -> io::Result<ProcMacroServerProcess> {
132 const VERSION: Version = Version::new(1, 93, 0);
133 let has_working_format_flag = version.map_or(false, |v| {
135 if v.pre.as_str() == "nightly" { *v > VERSION } else { *v >= VERSION }
136 });
137
138 let formats: &[_] = if std::env::var_os("RUST_ANALYZER_USE_POSTCARD").is_some()
139 && has_working_format_flag
140 {
141 &[
142 Some(ProtocolFormat::BidirectionalPostcardPrototype),
143 Some(ProtocolFormat::JsonLegacy),
144 ]
145 } else {
146 &[None]
147 };
148
149 let mut err = None;
150 for &format in formats {
151 let create_srv = || {
152 let (process, stdin, stdout) = spawn(format)?;
153
154 io::Result::Ok(ProcMacroServerProcess {
155 state: Mutex::new(ProcessSrvState { process, stdin, stdout }),
156 version: 0,
157 protocol: match format {
158 Some(ProtocolFormat::BidirectionalPostcardPrototype) => {
159 Protocol::BidirectionalPostcardPrototype { mode: SpanMode::Id }
160 }
161 Some(ProtocolFormat::JsonLegacy) | None => {
162 Protocol::LegacyJson { mode: SpanMode::Id }
163 }
164 },
165 exited: OnceLock::new(),
166 active: AtomicU32::new(0),
167 })
168 };
169 let mut srv = create_srv()?;
170 tracing::info!("sending proc-macro server version check");
171 match srv.version_check(Some(&reject_subrequests)) {
172 Ok(v) if v > version::CURRENT_API_VERSION => {
173 let process_version = binary_server_version();
174 err = Some(io::Error::other(format!(
175 "Your installed proc-macro server is too new for your rust-analyzer. API version: {}, server version: {process_version}. \
176 This will prevent proc-macro expansion from working. Please consider updating your rust-analyzer to ensure compatibility with your current toolchain.",
177 version::CURRENT_API_VERSION
178 )));
179 }
180 Ok(v) => {
181 tracing::info!("Proc-macro server version: {v}");
182 srv.version = v;
183 if srv.version >= version::RUST_ANALYZER_SPAN_SUPPORT
184 && let Ok(new_mode) =
185 srv.enable_rust_analyzer_spans(Some(&reject_subrequests))
186 {
187 match &mut srv.protocol {
188 Protocol::LegacyJson { mode }
189 | Protocol::BidirectionalPostcardPrototype { mode } => *mode = new_mode,
190 }
191 }
192 tracing::info!("Proc-macro server protocol: {:?}", srv.protocol);
193 return Ok(srv);
194 }
195 Err(e) => {
196 tracing::info!(%e, "proc-macro version check failed");
197 err = Some(io::Error::other(format!(
198 "proc-macro server version check failed: {e}"
199 )))
200 }
201 }
202 }
203 Err(err.unwrap())
204 }
205
206 pub(crate) fn find_proc_macros(
208 &self,
209 dylib_path: &AbsPath,
210 callback: Option<SubCallback<'_>>,
211 ) -> Result<Result<Vec<(String, ProcMacroKind)>, String>, ServerError> {
212 match self.protocol {
213 Protocol::LegacyJson { .. } => legacy_protocol::find_proc_macros(self, dylib_path),
214
215 Protocol::BidirectionalPostcardPrototype { .. } => {
216 let cb = callback.expect("callback required for bidirectional protocol");
217 bidirectional_protocol::find_proc_macros(self, dylib_path, cb)
218 }
219 }
220 }
221
222 pub(crate) fn exited(&self) -> Option<&ServerError> {
224 self.exited.get().map(|it| &it.0)
225 }
226
227 pub(crate) fn version(&self) -> u32 {
229 self.version
230 }
231
232 pub(crate) fn rust_analyzer_spans(&self) -> bool {
234 match self.protocol {
235 Protocol::LegacyJson { mode } => mode == SpanMode::RustAnalyzer,
236 Protocol::BidirectionalPostcardPrototype { mode } => mode == SpanMode::RustAnalyzer,
237 }
238 }
239
240 fn version_check(&self, callback: Option<SubCallback<'_>>) -> Result<u32, ServerError> {
242 match self.protocol {
243 Protocol::LegacyJson { .. } => legacy_protocol::version_check(self),
244 Protocol::BidirectionalPostcardPrototype { .. } => {
245 let cb = callback.expect("callback required for bidirectional protocol");
246 bidirectional_protocol::version_check(self, cb)
247 }
248 }
249 }
250
251 fn enable_rust_analyzer_spans(
253 &self,
254 callback: Option<SubCallback<'_>>,
255 ) -> Result<SpanMode, ServerError> {
256 match self.protocol {
257 Protocol::LegacyJson { .. } => legacy_protocol::enable_rust_analyzer_spans(self),
258 Protocol::BidirectionalPostcardPrototype { .. } => {
259 let cb = callback.expect("callback required for bidirectional protocol");
260 bidirectional_protocol::enable_rust_analyzer_spans(self, cb)
261 }
262 }
263 }
264
265 pub(crate) fn expand(
266 &self,
267 proc_macro: &ProcMacro,
268 subtree: tt::SubtreeView<'_>,
269 attr: Option<tt::SubtreeView<'_>>,
270 env: Vec<(String, String)>,
271 def_site: Span,
272 call_site: Span,
273 mixed_site: Span,
274 current_dir: String,
275 callback: Option<SubCallback<'_>>,
276 ) -> Result<Result<tt::TopSubtree, String>, ServerError> {
277 self.active.fetch_add(1, Ordering::AcqRel);
278 let result = match self.protocol {
279 Protocol::LegacyJson { .. } => legacy_protocol::expand(
280 proc_macro,
281 self,
282 subtree,
283 attr,
284 env,
285 def_site,
286 call_site,
287 mixed_site,
288 current_dir,
289 ),
290 Protocol::BidirectionalPostcardPrototype { .. } => bidirectional_protocol::expand(
291 proc_macro,
292 self,
293 subtree,
294 attr,
295 env,
296 def_site,
297 call_site,
298 mixed_site,
299 current_dir,
300 callback.expect("callback required for bidirectional protocol"),
301 ),
302 };
303
304 self.active.fetch_sub(1, Ordering::AcqRel);
305 result
306 }
307
308 pub(crate) fn send_task_legacy<Request, Response>(
309 &self,
310 send: impl FnOnce(
311 &mut dyn Write,
312 &mut dyn BufRead,
313 Request,
314 &mut String,
315 ) -> Result<Option<Response>, ServerError>,
316 req: Request,
317 ) -> Result<Response, ServerError> {
318 self.with_locked_io(String::new(), |writer, reader, buf| {
319 send(writer, reader, req, buf).and_then(|res| {
320 res.ok_or_else(|| {
321 let message = "proc-macro server did not respond with data".to_owned();
322 ServerError {
323 io: Some(Arc::new(io::Error::new(
324 io::ErrorKind::BrokenPipe,
325 message.clone(),
326 ))),
327 message,
328 }
329 })
330 })
331 })
332 }
333
334 fn with_locked_io<R, B>(
335 &self,
336 mut buf: B,
337 f: impl FnOnce(&mut dyn Write, &mut dyn BufRead, &mut B) -> Result<R, ServerError>,
338 ) -> Result<R, ServerError> {
339 let state = &mut *self.state.lock().unwrap();
340 f(&mut state.stdin, &mut state.stdout, &mut buf).map_err(|e| {
341 if e.io.as_ref().map(|it| it.kind()) == Some(io::ErrorKind::BrokenPipe) {
342 match state.process.exit_err() {
343 None => e,
344 Some(server_error) => {
345 self.exited.get_or_init(|| AssertUnwindSafe(server_error)).0.clone()
346 }
347 }
348 } else {
349 e
350 }
351 })
352 }
353
354 pub(crate) fn run_bidirectional(
355 &self,
356 initial: BidirectionalMessage,
357 callback: SubCallback<'_>,
358 ) -> Result<BidirectionalMessage, ServerError> {
359 self.with_locked_io(Vec::new(), |writer, reader, buf| {
360 bidirectional_protocol::run_conversation(writer, reader, buf, initial, callback)
361 })
362 }
363
364 pub(crate) fn number_of_active_req(&self) -> u32 {
365 self.active.load(Ordering::Acquire)
366 }
367}
368
/// A spawned proc-macro server child process.
#[derive(Debug)]
struct Process {
    /// `stdx::JodChild` wrapper around the `std::process::Child` returned by `mk_child`.
    child: JodChild,
}
374
375impl Process {
376 fn run<'a>(
378 path: &AbsPath,
379 env: impl IntoIterator<
380 Item = (impl AsRef<std::ffi::OsStr>, &'a Option<impl 'a + AsRef<std::ffi::OsStr>>),
381 >,
382 format: Option<&str>,
383 ) -> io::Result<Process> {
384 let child = JodChild(mk_child(path, env, format)?);
385 Ok(Process { child })
386 }
387
388 fn stdio(&mut self) -> Option<(ChildStdin, BufReader<ChildStdout>)> {
390 let stdin = self.child.stdin.take()?;
391 let stdout = self.child.stdout.take()?;
392 let read = BufReader::new(stdout);
393
394 Some((stdin, read))
395 }
396}
397
398fn mk_child<'a>(
400 path: &AbsPath,
401 extra_env: impl IntoIterator<
402 Item = (impl AsRef<std::ffi::OsStr>, &'a Option<impl 'a + AsRef<std::ffi::OsStr>>),
403 >,
404 format: Option<&str>,
405) -> io::Result<Child> {
406 #[allow(clippy::disallowed_methods)]
407 let mut cmd = Command::new(path);
408 for env in extra_env {
409 match env {
410 (key, Some(val)) => cmd.env(key, val),
411 (key, None) => cmd.env_remove(key),
412 };
413 }
414 if let Some(format) = format {
415 cmd.arg("--format");
416 cmd.arg(format);
417 }
418 cmd.env("RUST_ANALYZER_INTERNALS_DO_NOT_USE", "this is unstable")
419 .stdin(Stdio::piped())
420 .stdout(Stdio::piped())
421 .stderr(Stdio::inherit());
422 if cfg!(windows) {
423 let mut path_var = std::ffi::OsString::new();
424 path_var.push(path.parent().unwrap().parent().unwrap());
425 path_var.push("\\bin;");
426 path_var.push(std::env::var_os("PATH").unwrap_or_default());
427 cmd.env("PATH", path_var);
428 }
429 cmd.spawn()
430}