diff --git a/src/adapter/memory_adapter.rs b/src/adapter/memory_adapter.rs index e1e7fb63..7e97401b 100644 --- a/src/adapter/memory_adapter.rs +++ b/src/adapter/memory_adapter.rs @@ -1,18 +1,52 @@ use crate::{ adapter::{Adapter, Filter}, model::Model, + util::parse_csv_line, Result, }; use async_trait::async_trait; use hashlink::LinkedHashSet; +use super::StringAdapter; + #[derive(Default)] pub struct MemoryAdapter { policy: LinkedHashSet>, is_filtered: bool, } +impl From for MemoryAdapter { + fn from(string_adatpter: StringAdapter) -> Self { + let string_policies = string_adatpter.policy.split("\n"); + let mut memory_adapter = Self { + policy: LinkedHashSet::new(), + is_filtered: false, + }; + for line in string_policies { + if let Some(tokens) = parse_csv_line(line) { + let ptype = tokens[0].clone(); + if let Some(sec) = ptype.chars().next().map(|x| x.to_string()) { + let mut rule = tokens[1..].to_vec(); + rule.insert(0, ptype); + rule.insert(0, sec); + memory_adapter.policy.insert(rule); + } + } + } + memory_adapter + } +} + +#[allow(clippy::should_implement_trait)] +impl MemoryAdapter { + pub fn from_str(s: impl ToString) -> Self { + let s = s.to_string(); + let string_adapter = StringAdapter::new(s); + Self::from(string_adapter) + } +} + #[async_trait] impl Adapter for MemoryAdapter { async fn load_policy(&mut self, m: &mut dyn Model) -> Result<()> { @@ -232,3 +266,132 @@ impl Adapter for MemoryAdapter { self.is_filtered } } + +#[cfg(test)] +mod test { + use hashlink::LinkedHashSet; + #[cfg(target_arch = "wasm32")] + use wasm_bindgen_test::*; + + use crate::{ + adapter::StringAdapter, Adapter, CoreApi, DefaultModel, Enforcer, + Filter, MemoryAdapter, + }; + + #[cfg_attr( + all(not(target_arch = "wasm32"), feature = "runtime-async-std"), + async_std::test + )] + #[cfg_attr( + all(not(target_arch = "wasm32"), feature = "runtime-tokio"), + tokio::test + )] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + async fn test_load_policy() { + let policy = "p, alice, data1, read\np, bob, data2, write"; + let mut adapter = MemoryAdapter::from_str(policy); + let mut model = DefaultModel::from_str(include_str!( + "../../examples/rbac_model.conf" + )) + .await + .unwrap(); + + adapter.load_policy(&mut model).await.unwrap(); + let enforcer = Enforcer::new(model, adapter).await.unwrap(); + + assert!(enforcer.enforce(("alice", "data1", "read")).unwrap()); + assert!(enforcer.enforce(("bob", "data2", "write")).unwrap()); + assert!(!enforcer.enforce(("alice", "data2", "read")).unwrap()); + } + + #[cfg_attr( + all(not(target_arch = "wasm32"), feature = "runtime-async-std"), + async_std::test + )] + #[cfg_attr( + all(not(target_arch = "wasm32"), feature = "runtime-tokio"), + tokio::test + )] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + async fn test_save_policy() { + let policy = "p, alice, data1, read\np, bob, data2, write"; + let mut adapter = MemoryAdapter::from_str(policy); + let mut model = DefaultModel::from_str(include_str!( + "../../examples/rbac_model.conf" + )) + .await + .unwrap(); + + adapter.load_policy(&mut model).await.unwrap(); + adapter.save_policy(&mut model).await.unwrap(); + + let mut expected = LinkedHashSet::new(); + expected.insert( + vec!["p", "p", "alice", "data1", "read"] + .iter() + .map(|s| s.to_string()) + .collect::>(), + ); + expected.insert( + vec!["p", "p", "bob", "data2", "write"] + .iter() + .map(|s| s.to_string()) + .collect::>(), + ); + + assert_eq!(adapter.policy, expected); + } + + #[cfg_attr( + all(not(target_arch = "wasm32"), feature = "runtime-async-std"), + async_std::test + )] + #[cfg_attr( + all(not(target_arch = "wasm32"), feature = "runtime-tokio"), + tokio::test + )] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + async fn test_compare_string_adapter() { + let policy = "p, alice, data1, read\np, bob, data2, write"; + let mut string_adapter = StringAdapter::new(policy); + let mut memory_adapter = MemoryAdapter::from_str(policy); + let mut model = DefaultModel::from_str(include_str!( + "../../examples/rbac_model.conf" + )) + .await + .unwrap(); + + assert_eq!( + string_adapter.load_policy(&mut model).await.unwrap(), + memory_adapter.load_policy(&mut model).await.unwrap() + ); + + let filter = Filter { + p: vec!["alice"], + g: vec![], + }; + + assert_eq!( + string_adapter + .load_filtered_policy(&mut model, filter.clone()) + .await + .unwrap(), + memory_adapter + .load_filtered_policy(&mut model, filter) + .await + .unwrap(), + ); + + assert_eq!(string_adapter.is_filtered(), memory_adapter.is_filtered()); + + let string_enforcer = + Enforcer::new(model.clone(), string_adapter).await.unwrap(); + let memory_enforcer = + Enforcer::new(model.clone(), memory_adapter).await.unwrap(); + + assert_eq!( + string_enforcer.enforce(("alice", "data1", "read")).unwrap(), + memory_enforcer.enforce(("alice", "data1", "read")).unwrap() + ); + } +}