From 13a96efacbf2068b2ed019a393132950fe3762ab Mon Sep 17 00:00:00 2001 From: "Dustin J. Mitchell" Date: Thu, 7 Oct 2021 21:24:27 -0400 Subject: [PATCH] Add snapshot encoding / decoding --- Cargo.lock | 1 + taskchampion/Cargo.toml | 1 + taskchampion/src/taskdb/mod.rs | 1 + taskchampion/src/taskdb/snapshot.rs | 186 ++++++++++++++++++++++++++++ 4 files changed, 189 insertions(+) create mode 100644 taskchampion/src/taskdb/snapshot.rs diff --git a/Cargo.lock b/Cargo.lock index bb79505da..10f51edf9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2962,6 +2962,7 @@ version = "0.4.1" dependencies = [ "anyhow", "chrono", + "flate2", "log", "pretty_assertions", "proptest", diff --git a/taskchampion/Cargo.toml b/taskchampion/Cargo.toml index 168d4ba1f..c5da9898c 100644 --- a/taskchampion/Cargo.toml +++ b/taskchampion/Cargo.toml @@ -23,6 +23,7 @@ tindercrypt = { version = "^0.2.2", default-features = false } rusqlite = { version = "0.25", features = ["bundled"] } strum = "0.21" strum_macros = "0.21" +flate2 = "1" [dev-dependencies] proptest = "^1.0.0" diff --git a/taskchampion/src/taskdb/mod.rs b/taskchampion/src/taskdb/mod.rs index 850628719..0394b8c85 100644 --- a/taskchampion/src/taskdb/mod.rs +++ b/taskchampion/src/taskdb/mod.rs @@ -3,6 +3,7 @@ use crate::storage::{Operation, Storage, TaskMap}; use uuid::Uuid; mod ops; +mod snapshot; mod sync; mod working_set; diff --git a/taskchampion/src/taskdb/snapshot.rs b/taskchampion/src/taskdb/snapshot.rs new file mode 100644 index 000000000..e054612b3 --- /dev/null +++ b/taskchampion/src/taskdb/snapshot.rs @@ -0,0 +1,186 @@ +use crate::storage::{StorageTxn, TaskMap, VersionId}; +use flate2::{read::ZlibDecoder, write::ZlibEncoder, Compression}; +use serde::de::{Deserialize, Deserializer, MapAccess, Visitor}; +use serde::ser::{Serialize, SerializeMap, Serializer}; +use std::fmt; +use uuid::Uuid; + +/// A newtype to wrap the result of [`crate::storage::StorageTxn::all_tasks`] +pub(super) struct SnapshotTasks(Vec<(Uuid, TaskMap)>); + +impl Serialize for SnapshotTasks { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut map = serializer.serialize_map(Some(self.0.len()))?; + for (k, v) in &self.0 { + map.serialize_entry(k, v)?; + } + map.end() + } +} + +struct TaskDbVisitor; + +impl<'de> Visitor<'de> for TaskDbVisitor { + type Value = SnapshotTasks; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a map representing a task snapshot") + } + + fn visit_map(self, mut access: M) -> Result + where + M: MapAccess<'de>, + { + let mut map = SnapshotTasks(Vec::with_capacity(access.size_hint().unwrap_or(0))); + + while let Some((key, value)) = access.next_entry()? { + map.0.push((key, value)); + } + + Ok(map) + } +} + +impl<'de> Deserialize<'de> for SnapshotTasks { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_map(TaskDbVisitor) + } +} + +impl SnapshotTasks { + pub(super) fn encode(&self) -> anyhow::Result> { + let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default()); + serde_json::to_writer(&mut encoder, &self)?; + Ok(encoder.finish()?) + } + + pub(super) fn decode(snapshot: &[u8]) -> anyhow::Result { + let decoder = ZlibDecoder::new(snapshot); + Ok(serde_json::from_reader(decoder)?) + } + + pub(super) fn into_inner(self) -> Vec<(Uuid, TaskMap)> { + self.0 + } +} + +#[allow(dead_code)] +/// Generate a snapshot (compressed, unencrypted) for the current state of the taskdb in the given +/// storage. +pub(super) fn make_snapshot(txn: &mut dyn StorageTxn) -> anyhow::Result> { + let all_tasks = SnapshotTasks(txn.all_tasks()?); + all_tasks.encode() +} + +#[allow(dead_code)] +/// Apply the given snapshot (compressed, unencrypted) to the taskdb's storage. +pub(super) fn apply_snapshot( + txn: &mut dyn StorageTxn, + version: VersionId, + snapshot: &[u8], +) -> anyhow::Result<()> { + let all_tasks = SnapshotTasks::decode(snapshot)?; + + // first, verify that the taskdb truly is empty + let mut empty = true; + empty = empty && txn.all_tasks()?.is_empty(); + empty = empty && txn.get_working_set()? == vec![None]; + empty = empty && txn.base_version()? == Uuid::nil(); + empty = empty && txn.operations()?.is_empty(); + + if !empty { + anyhow::bail!("Cannot apply snapshot to a non-empty task database"); + } + + for (uuid, task) in all_tasks.into_inner().drain(..) { + txn.set_task(uuid, task)?; + } + txn.set_base_version(version)?; + + Ok(()) +} + +#[cfg(test)] +mod test { + use super::*; + use crate::storage::{InMemoryStorage, Storage, TaskMap}; + use pretty_assertions::assert_eq; + + #[test] + fn test_serialize_empty() -> anyhow::Result<()> { + let empty = SnapshotTasks(vec![]); + assert_eq!(serde_json::to_vec(&empty)?, b"{}".to_owned()); + Ok(()) + } + + #[test] + fn test_serialize_tasks() -> anyhow::Result<()> { + let u = Uuid::new_v4(); + let m: TaskMap = vec![("description".to_owned(), "my task".to_owned())] + .drain(..) + .collect(); + let all_tasks = SnapshotTasks(vec![(u, m)]); + assert_eq!( + serde_json::to_vec(&all_tasks)?, + format!("{{\"{}\":{{\"description\":\"my task\"}}}}", u).into_bytes(), + ); + Ok(()) + } + + #[test] + fn test_round_trip() -> anyhow::Result<()> { + let mut storage = InMemoryStorage::new(); + let version = Uuid::new_v4(); + + let task1 = ( + Uuid::new_v4(), + vec![("description".to_owned(), "one".to_owned())] + .drain(..) + .collect::(), + ); + let task2 = ( + Uuid::new_v4(), + vec![("description".to_owned(), "two".to_owned())] + .drain(..) + .collect::(), + ); + + { + let mut txn = storage.txn()?; + txn.set_task(task1.0, task1.1.clone())?; + txn.set_task(task2.0, task2.1.clone())?; + txn.commit()?; + } + + let snap = { + let mut txn = storage.txn()?; + make_snapshot(txn.as_mut())? + }; + + // apply that snapshot to a fresh bit of fake + let mut storage = InMemoryStorage::new(); + { + let mut txn = storage.txn()?; + apply_snapshot(txn.as_mut(), version, &snap)?; + txn.commit()? + } + + { + let mut txn = storage.txn()?; + assert_eq!(txn.get_task(task1.0)?, Some(task1.1)); + assert_eq!(txn.get_task(task2.0)?, Some(task2.1)); + assert_eq!(txn.all_tasks()?.len(), 2); + assert_eq!(txn.base_version()?, version); + assert_eq!(txn.operations()?.len(), 0); + assert_eq!(txn.get_working_set()?.len(), 1); + } + + Ok(()) + } +}