diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index c011f1e8..8bf8c25a 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -917,27 +917,11 @@ where panic!("Overlapping method route. Cannot merge two method routes that both define `{name}`") } } - (Some(svc), None) => Some(svc), - (None, Some(svc)) => Some(svc), + (Some(svc), None) | (None, Some(svc)) => Some(svc), (None, None) => None, } } - #[track_caller] - fn merge_fallback( - fallback: Fallback, - fallback_other: Fallback, - ) -> Fallback { - match (fallback, fallback_other) { - (pick @ Fallback::Default(_), Fallback::Default(_)) => pick, - (Fallback::Default(_), pick @ Fallback::Custom(_)) => pick, - (pick @ Fallback::Custom(_), Fallback::Default(_)) => pick, - (Fallback::Custom(_), Fallback::Custom(_)) => { - panic!("Cannot merge two `MethodRouter`s that both have a fallback") - } - } - } - self.get = merge_inner(path, "GET", self.get, other.get); self.head = merge_inner(path, "HEAD", self.head, other.head); self.delete = merge_inner(path, "DELETE", self.delete, other.delete); @@ -947,7 +931,10 @@ where self.put = merge_inner(path, "PUT", self.put, other.put); self.trace = merge_inner(path, "TRACE", self.trace, other.trace); - self.fallback = merge_fallback(self.fallback, other.fallback); + self.fallback = self + .fallback + .merge(other.fallback) + .expect("Cannot merge two `MethodRouter`s that both have a fallback"); self.allow_header = self.allow_header.merge(other.allow_header); diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index d33918e1..1a107560 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -323,14 +323,10 @@ where }; } - self.fallback = match (self.fallback, fallback) { - (Fallback::Default(_), pick @ Fallback::Default(_)) => pick, - (Fallback::Default(_), pick @ Fallback::Custom(_)) => pick, - (pick @ Fallback::Custom(_), Fallback::Default(_)) => pick, - (Fallback::Custom(_), Fallback::Custom(_)) => { - panic!("Cannot merge two `Router`s that both have a fallback") - } - }; + self.fallback = self + .fallback + .merge(fallback) + .expect("Cannot merge two `Router`s that both have a fallback"); self } @@ -628,11 +624,21 @@ enum Fallback { Custom(Route), } +impl Fallback { + fn merge(self, other: Self) -> Option { + match (self, other) { + (Self::Default(_), pick @ Self::Default(_)) => Some(pick), + (Self::Default(_), pick) | (pick, Self::Default(_)) => Some(pick), + _ => None, + } + } +} + impl Clone for Fallback { fn clone(&self) -> Self { match self { - Fallback::Default(inner) => Fallback::Default(inner.clone()), - Fallback::Custom(inner) => Fallback::Custom(inner.clone()), + Self::Default(inner) => Self::Default(inner.clone()), + Self::Custom(inner) => Self::Custom(inner.clone()), } } } @@ -652,8 +658,8 @@ impl Fallback { F: FnOnce(Route) -> Route, { match self { - Fallback::Default(inner) => Fallback::Default(f(inner)), - Fallback::Custom(inner) => Fallback::Custom(f(inner)), + Self::Default(inner) => Fallback::Default(f(inner)), + Self::Custom(inner) => Fallback::Custom(f(inner)), } } } @@ -666,8 +672,8 @@ enum Endpoint { impl Clone for Endpoint { fn clone(&self) -> Self { match self { - Endpoint::MethodRouter(inner) => Endpoint::MethodRouter(inner.clone()), - Endpoint::Route(inner) => Endpoint::Route(inner.clone()), + Self::MethodRouter(inner) => Self::MethodRouter(inner.clone()), + Self::Route(inner) => Self::Route(inner.clone()), } } }