Skip to content

Add stage ids for DynPipeline and ability to retrieve stages with concrete type #320

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ ctrlc = { version = "3.4.5", default-features = false, features = [] }
doxygen-rs = { version = "0.4.2", default-features = false, features = [] }
dyn-iter = { version = "1.0.1", default-features = false, features = [] }
etherparse = { version = "0.17.0", default-features = false, features = [] }
ordermap = { version = "0.5.5", default-features = false, features = [] }
ipnet = { version = "2.11.0", default-features = false, features = [] }
iptrie = { version = "0.10.2", default-features = false, features = [] }
mio = { version = "1.0.3", default-features = false, features = [] }
Expand Down
2 changes: 2 additions & 0 deletions dataplane/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ license = "Apache-2.0"
clap = { workspace = true, features = ["derive"] }
ctrlc = { workspace = true, features = ["termination"] }
dpdk = { workspace = true }
id = { workspace = true }
iptrie = { workspace = true }
dyn-iter = { workspace = true }
net = { workspace = true, features = ["serde"] }
ordermap = { workspace = true, features = ["std"] }
routing = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_yml = { workspace = true }
Expand Down
1 change: 0 additions & 1 deletion dataplane/src/nat/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright Open Network Fabric Authors

#![allow(dead_code)]
#![allow(rustdoc::private_doc_tests)]

Expand Down
6 changes: 5 additions & 1 deletion dataplane/src/pipeline/dyn_nf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub trait DynNetworkFunction<Buf: PacketBufferMut>: Any {
fn process_dyn<'a>(&'a mut self, input: DynIter<'a, Packet<Buf>>) -> DynIter<'a, Packet<Buf>>;
}

struct DynNetworkFunctionImpl<Buf: PacketBufferMut, NF: NetworkFunction<Buf> + 'static> {
pub(crate) struct DynNetworkFunctionImpl<Buf: PacketBufferMut, NF: NetworkFunction<Buf> + 'static> {
nf: NF,
_marker: PhantomData<Buf>,
}
Expand All @@ -44,6 +44,10 @@ impl<Buf: PacketBufferMut, NF: NetworkFunction<Buf>> DynNetworkFunctionImpl<Buf,
_marker: PhantomData,
}
}

pub fn get_nf(&self) -> &NF {
&self.nf
}
}

/// Creates a boxed, dynamic network function.
Expand Down
2 changes: 1 addition & 1 deletion dataplane/src/pipeline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ pub(crate) mod test_utils;
#[allow(unused)]
pub use dyn_nf::{DynNetworkFunction, nf_dyn};
#[allow(unused)]
pub use pipeline::DynPipeline;
pub use pipeline::{DynPipeline, StageId};
#[allow(unused)]
pub use static_nf::{NetworkFunction, StaticChain};

Expand Down
174 changes: 168 additions & 6 deletions dataplane/src/pipeline/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@
// Copyright Open Network Fabric Authors

use dyn_iter::{DynIter, IntoDynIterator};
use id::Id;
use net::buffer::PacketBufferMut;
use ordermap::OrderMap;
use std::any::Any;

use crate::packet::Packet;
use crate::pipeline::dyn_nf::DynNetworkFunctionImpl;
use crate::pipeline::{DynNetworkFunction, NetworkFunction, nf_dyn};

pub type StageId<Buf> = Id<Box<dyn DynNetworkFunction<Buf>>>;

/// A dynamic pipeline that can be updated at runtime.
///
/// This struct is used to create a dynamic pipeline that can be updated at runtime.
Expand All @@ -16,13 +22,21 @@ use crate::pipeline::{DynNetworkFunction, NetworkFunction, nf_dyn};
/// [`DynNetworkFunction`]
#[derive(Default)]
pub struct DynPipeline<Buf: PacketBufferMut> {
nfs: Vec<Box<dyn DynNetworkFunction<Buf>>>,
nfs: OrderMap<StageId<Buf>, Box<dyn DynNetworkFunction<Buf>>>,
}

#[derive(Debug, thiserror::Error)]
pub enum PipelineError {
#[error("Duplicate stage id: {0}")]
DuplicateStageId(String),
}

impl<Buf: PacketBufferMut + 'static> DynPipeline<Buf> {
impl<Buf: PacketBufferMut> DynPipeline<Buf> {
#[allow(unused)]
pub fn new() -> Self {
Self { nfs: vec![] }
Self {
nfs: OrderMap::new(),
}
}

/// Add a static network function to the pipeline.
Expand All @@ -34,6 +48,19 @@ impl<Buf: PacketBufferMut + 'static> DynPipeline<Buf> {
self.add_stage_dyn(nf_dyn(nf))
}

/// Add a static network function to the pipeline using a specific stage id.
///
/// This method takes a [`NetworkFunction`] and adds it to the pipeline.
///
#[allow(unused)]
pub fn add_stage_with_id<NF: NetworkFunction<Buf> + 'static>(
&mut self,
id: StageId<Buf>,
nf: NF,
) -> Result<&mut Self, PipelineError> {
self.add_stage_dyn_with_id(id, nf_dyn(nf))
}

/// Add a dynamic network function to the pipeline.
///
/// This method takes a [`DynNetworkFunction`] and adds it to the pipeline.
Expand All @@ -44,15 +71,100 @@ impl<Buf: PacketBufferMut + 'static> DynPipeline<Buf> {
/// [`nf_dyn`]
#[allow(unused)]
pub fn add_stage_dyn(mut self, nf: Box<dyn DynNetworkFunction<Buf>>) -> Self {
self.nfs.push(nf);
self.internal_add_stage_dyn_with_id(StageId::<Buf>::new(), nf);
self
}

/// Add a dynamic network function to the pipeline using a specific stage id.
///
/// This method takes a [`DynNetworkFunction`] and adds it to the pipeline.
///
/// # See Also
///
/// [`DynNetworkFunction`]
/// [`nf_dyn`]
/// Add a dynamic network function to the pipeline using a specific stage id.
///
/// This method takes a [`DynNetworkFunction`] and adds it to the pipeline.
///
/// # See Also
///
/// [`DynNetworkFunction`]
/// [`nf_dyn`]
/// Add a dynamic network function to the pipeline using a specific stage id.
///
/// This method takes a [`DynNetworkFunction`] and adds it to the pipeline.
///
/// # See Also
///
/// [`DynNetworkFunction`]
/// [`nf_dyn`]
pub fn add_stage_dyn_with_id(
&mut self,
id: StageId<Buf>,
nf: Box<dyn DynNetworkFunction<Buf>>,
) -> Result<&mut Self, PipelineError> {
self.internal_add_stage_dyn_with_id(id, nf)
}

fn internal_add_stage_dyn_with_id(
&mut self,
id: StageId<Buf>,
nf: Box<dyn DynNetworkFunction<Buf>>,
) -> Result<&mut Self, PipelineError> {
// FIXME(mvachhar): There seems to be no method to insert and error if the key already exists.
// As a result, this does a double hash and lookup. Probably fine here, but may need to submit
// a patch to ordermap to add this functionality in other places.
//
// When [this](https://github.com/rust-lang/rust/issues/82766) becomes stable, we should move
// to using `try_insert` instead of `get` then `insert`.
if self.nfs.get(&id).is_some() {
Err(PipelineError::DuplicateStageId(id.to_string()))
} else {
self.nfs.insert(id, nf);
Ok(self)
}
}

/// Get a static network function from the pipeline by stage id.
/// Get a dynamic network function from the pipeline by stage id.
///
/// This method takes a stage id and returns the [`DynNetworkFunction`] associated with that stage.
///
/// # See Also
///
///
/// This method takes a stage id and returns the [`NetworkFunction`] associated with that stage.
///
/// # See Also
///
#[allow(unused)]
pub fn get_stage_by_id<T: NetworkFunction<Buf> + 'static>(
&self,
id: &StageId<Buf>,
) -> Option<&T> {
self.get_stage_dyn_by_id::<DynNetworkFunctionImpl<Buf, T>>(id)
.map(DynNetworkFunctionImpl::get_nf)
}

/// Get a dynamic network function from the pipeline by stage id.
///
/// This method takes a stage id and returns the [`DynNetworkFunction`] associated with that stage.
///
/// # See Also
///
#[allow(unused)]
pub fn get_stage_dyn_by_id<T: DynNetworkFunction<Buf>>(&self, id: &StageId<Buf>) -> Option<&T> {
self.nfs
.get(id)
.and_then(|nf| (&**nf as &dyn Any).downcast_ref::<T>())
}
}

impl<Buf: PacketBufferMut> DynNetworkFunction<Buf> for DynPipeline<Buf> {
fn process_dyn<'a>(&'a mut self, input: DynIter<'a, Packet<Buf>>) -> DynIter<'a, Packet<Buf>> {
self.nfs
.iter_mut()
.values_mut()
.fold(input, move |input, nf| nf.process_dyn(input))
.into_dyn_iter()
}
Expand All @@ -71,11 +183,16 @@ impl<Buf: PacketBufferMut> NetworkFunction<Buf> for DynPipeline<Buf> {
#[cfg(test)]
mod test {
use dyn_iter::IntoDynIterator;
use net::buffer::TestBuffer;
use net::eth::mac::{DestinationMac, Mac};
use net::headers::{Net, TryEth, TryIp, TryIpv4};

use crate::pipeline::dyn_nf::DynNetworkFunctionImpl;
use crate::pipeline::sample_nfs::DecrementTtl;
use crate::pipeline::test_utils::{DynStageGenerator, build_test_ipv4_packet};
use crate::pipeline::{DynNetworkFunction, DynPipeline, NetworkFunction};
use crate::pipeline::{DynNetworkFunction, DynPipeline, NetworkFunction, StageId};

type TestStageId = StageId<TestBuffer>;

#[test]
fn long_dyn_pipeline() {
Expand Down Expand Up @@ -158,4 +275,49 @@ mod test {
panic!("Expected IPv4 packet");
}
}

#[test]
fn get_stage_by_id() {
let mut pipeline = DynPipeline::new();
let mut stages = DynStageGenerator::new();
let num_stages = 10u16;
let test_stage_id = TestStageId::new();

for i in 0..num_stages {
if i == 5 {
pipeline
.add_stage_with_id(test_stage_id, DecrementTtl)
.unwrap();
} else {
pipeline = pipeline.add_stage_dyn(stages.next().unwrap());
}
}

let stage = pipeline.get_stage_by_id::<DecrementTtl>(&test_stage_id);
assert!(stage.is_some());
}

#[test]
fn get_stage_dyn_by_id() {
let mut pipeline = DynPipeline::new();
let mut stages = DynStageGenerator::new();
let num_stages = 10u16;
let test_stage_id = TestStageId::new();

for i in 0..num_stages {
if i == 5 {
pipeline
.add_stage_with_id(test_stage_id, DecrementTtl)
.unwrap();
} else {
pipeline = pipeline.add_stage_dyn(stages.next().unwrap());
}
}

let stage = pipeline
.get_stage_dyn_by_id::<DynNetworkFunctionImpl<TestBuffer, DecrementTtl>>(
&test_stage_id,
);
assert!(stage.is_some());
}
}
Loading