From 312b7b3704e012f179fe4b57505a0509459fd958 Mon Sep 17 00:00:00 2001 From: Rob Young Date: Mon, 30 Nov 2020 21:01:11 +0000 Subject: [PATCH] Implement DurationRound for NaiveDateTime This is off the back of [this comment](https://github.com/chronotope/chrono/pull/445#issuecomment-717323407). --- CHANGELOG.md | 1 + src/round.rs | 200 +++++++++++++++++++++++++++++++++++++++------------ 2 files changed, 156 insertions(+), 45 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 039afd48..16a9063b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ Versions with only mechanical changes will be omitted from the following list. * Add more formatting documentation and examples. * Add support for microseconds timestamps serde serialization/deserialization (#304) * Fix `DurationRound` is not TZ aware (#495) +* Implement `DurationRound` for `NaiveDateTime` ## 0.4.19 diff --git a/src/round.rs b/src/round.rs index f391d1d8..61f53d34 100644 --- a/src/round.rs +++ b/src/round.rs @@ -6,6 +6,7 @@ use core::fmt; use core::marker::Sized; use core::ops::{Add, Sub}; use datetime::DateTime; +use naive::NaiveDateTime; use oldtime::Duration; #[cfg(any(feature = "std", test))] use std; @@ -150,56 +151,86 @@ impl DurationRound for DateTime { type Err = RoundingError; fn duration_round(self, duration: Duration) -> Result { - if let Some(span) = duration.num_nanoseconds() { - let naive = self.naive_local(); - - if naive.timestamp().abs() > MAX_SECONDS_TIMESTAMP_FOR_NANOS { - return Err(RoundingError::TimestampExceedsLimit); - } - let stamp = naive.timestamp_nanos(); - if span > stamp.abs() { - return Err(RoundingError::DurationExceedsTimestamp); - } - let delta_down = stamp % span; - if delta_down == 0 { - Ok(self) - } else { - let (delta_up, delta_down) = if delta_down < 0 { - (delta_down.abs(), span - delta_down.abs()) - } else { - (span - delta_down, delta_down) - }; - if delta_up <= delta_down { - Ok(self + Duration::nanoseconds(delta_up)) - } else { - Ok(self - Duration::nanoseconds(delta_down)) - } - } - } else { - Err(RoundingError::DurationExceedsLimit) - } + duration_round(self.naive_local(), self, duration) } fn duration_trunc(self, duration: Duration) -> Result { - if let Some(span) = duration.num_nanoseconds() { - let naive = self.naive_local(); + duration_trunc(self.naive_local(), self, duration) + } +} - if naive.timestamp().abs() > MAX_SECONDS_TIMESTAMP_FOR_NANOS { - return Err(RoundingError::TimestampExceedsLimit); - } - let stamp = naive.timestamp_nanos(); - if span > stamp.abs() { - return Err(RoundingError::DurationExceedsTimestamp); - } - let delta_down = stamp % span; - match delta_down.cmp(&0) { - Ordering::Equal => Ok(self), - Ordering::Greater => Ok(self - Duration::nanoseconds(delta_down)), - Ordering::Less => Ok(self - Duration::nanoseconds(span - delta_down.abs())), - } - } else { - Err(RoundingError::DurationExceedsLimit) +impl DurationRound for NaiveDateTime { + type Err = RoundingError; + + fn duration_round(self, duration: Duration) -> Result { + duration_round(self, self, duration) + } + + fn duration_trunc(self, duration: Duration) -> Result { + duration_trunc(self, self, duration) + } +} + +fn duration_round( + naive: NaiveDateTime, + original: T, + duration: Duration, +) -> Result +where + T: Timelike + Add + Sub, +{ + if let Some(span) = duration.num_nanoseconds() { + if naive.timestamp().abs() > MAX_SECONDS_TIMESTAMP_FOR_NANOS { + return Err(RoundingError::TimestampExceedsLimit); } + let stamp = naive.timestamp_nanos(); + if span > stamp.abs() { + return Err(RoundingError::DurationExceedsTimestamp); + } + let delta_down = stamp % span; + if delta_down == 0 { + Ok(original) + } else { + let (delta_up, delta_down) = if delta_down < 0 { + (delta_down.abs(), span - delta_down.abs()) + } else { + (span - delta_down, delta_down) + }; + if delta_up <= delta_down { + Ok(original + Duration::nanoseconds(delta_up)) + } else { + Ok(original - Duration::nanoseconds(delta_down)) + } + } + } else { + Err(RoundingError::DurationExceedsLimit) + } +} + +fn duration_trunc( + naive: NaiveDateTime, + original: T, + duration: Duration, +) -> Result +where + T: Timelike + Add + Sub, +{ + if let Some(span) = duration.num_nanoseconds() { + if naive.timestamp().abs() > MAX_SECONDS_TIMESTAMP_FOR_NANOS { + return Err(RoundingError::TimestampExceedsLimit); + } + let stamp = naive.timestamp_nanos(); + if span > stamp.abs() { + return Err(RoundingError::DurationExceedsTimestamp); + } + let delta_down = stamp % span; + match delta_down.cmp(&0) { + Ordering::Equal => Ok(original), + Ordering::Greater => Ok(original - Duration::nanoseconds(delta_down)), + Ordering::Less => Ok(original - Duration::nanoseconds(span - delta_down.abs())), + } + } else { + Err(RoundingError::DurationExceedsLimit) } } @@ -423,6 +454,46 @@ mod tests { ); } + #[test] + fn test_duration_round_naive() { + let dt = Utc.ymd(2016, 12, 31).and_hms_nano(23, 59, 59, 175_500_000).naive_utc(); + + assert_eq!( + dt.duration_round(Duration::milliseconds(10)).unwrap().to_string(), + "2016-12-31 23:59:59.180" + ); + + // round up + let dt = Utc.ymd(2012, 12, 12).and_hms_milli(18, 22, 30, 0).naive_utc(); + assert_eq!( + dt.duration_round(Duration::minutes(5)).unwrap().to_string(), + "2012-12-12 18:25:00" + ); + // round down + let dt = Utc.ymd(2012, 12, 12).and_hms_milli(18, 22, 29, 999).naive_utc(); + assert_eq!( + dt.duration_round(Duration::minutes(5)).unwrap().to_string(), + "2012-12-12 18:20:00" + ); + + assert_eq!( + dt.duration_round(Duration::minutes(10)).unwrap().to_string(), + "2012-12-12 18:20:00" + ); + assert_eq!( + dt.duration_round(Duration::minutes(30)).unwrap().to_string(), + "2012-12-12 18:30:00" + ); + assert_eq!( + dt.duration_round(Duration::hours(1)).unwrap().to_string(), + "2012-12-12 18:00:00" + ); + assert_eq!( + dt.duration_round(Duration::days(1)).unwrap().to_string(), + "2012-12-13 00:00:00" + ); + } + #[test] fn test_duration_round_pre_epoch() { let dt = Utc.ymd(1969, 12, 12).and_hms(12, 12, 12); @@ -493,6 +564,45 @@ mod tests { ); } + #[test] + fn test_duration_trunc_naive() { + let dt = Utc.ymd(2016, 12, 31).and_hms_nano(23, 59, 59, 1_75_500_000).naive_utc(); + + assert_eq!( + dt.duration_trunc(Duration::milliseconds(10)).unwrap().to_string(), + "2016-12-31 23:59:59.170" + ); + + // would round up + let dt = Utc.ymd(2012, 12, 12).and_hms_milli(18, 22, 30, 0).naive_utc(); + assert_eq!( + dt.duration_trunc(Duration::minutes(5)).unwrap().to_string(), + "2012-12-12 18:20:00" + ); + // would round down + let dt = Utc.ymd(2012, 12, 12).and_hms_milli(18, 22, 29, 999).naive_utc(); + assert_eq!( + dt.duration_trunc(Duration::minutes(5)).unwrap().to_string(), + "2012-12-12 18:20:00" + ); + assert_eq!( + dt.duration_trunc(Duration::minutes(10)).unwrap().to_string(), + "2012-12-12 18:20:00" + ); + assert_eq!( + dt.duration_trunc(Duration::minutes(30)).unwrap().to_string(), + "2012-12-12 18:00:00" + ); + assert_eq!( + dt.duration_trunc(Duration::hours(1)).unwrap().to_string(), + "2012-12-12 18:00:00" + ); + assert_eq!( + dt.duration_trunc(Duration::days(1)).unwrap().to_string(), + "2012-12-12 00:00:00" + ); + } + #[test] fn test_duration_trunc_pre_epoch() { let dt = Utc.ymd(1969, 12, 12).and_hms(12, 12, 12);