Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Package definition on hugr-core #1587

Merged
merged 3 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion hugr-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ categories = ["compilers"]
[dependencies]
clap = { workspace = true, features = ["derive"] }
clap-verbosity-flag.workspace = true
hugr-core = { path = "../hugr-core", version = "0.13.1" }
derive_more = { workspace = true, features = ["display", "error", "from"] }
hugr = { path = "../hugr", version = "0.13.1" }
serde_json.workspace = true
serde.workspace = true
thiserror.workspace = true
Expand Down
2 changes: 1 addition & 1 deletion hugr-cli/src/extensions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Dump standard extensions in serialized form.
use clap::Parser;
use hugr_core::extension::ExtensionRegistry;
use hugr::extension::ExtensionRegistry;
use std::{io::Write, path::PathBuf};

/// Dump the standard extensions.
Expand Down
54 changes: 17 additions & 37 deletions hugr-cli/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
use clap::Parser;
use clap_verbosity_flag::{InfoLevel, Verbosity};
use clio::Input;
use hugr_core::{Extension, Hugr};
use derive_more::{Display, Error, From};
use hugr::package::{PackageEncodingError, PackageValidationError};
use std::{ffi::OsString, path::PathBuf};
use thiserror::Error;

pub mod extensions;
pub mod mermaid;
pub mod validate;

// TODO: Deprecated re-export. Remove on a breaking release.
pub use hugr::package::Package;

/// CLI arguments.
#[derive(Parser, Debug)]
#[clap(version = "1.0", long_about = None)]
Expand All @@ -30,18 +33,21 @@ pub enum CliArgs {
}

/// Error type for the CLI.
#[derive(Debug, Error)]
#[error(transparent)]
#[derive(Debug, Display, Error, From)]
#[non_exhaustive]
pub enum CliError {
/// Error reading input.
#[error("Error reading from path: {0}")]
InputFile(#[from] std::io::Error),
#[display("Error reading from path: {_0}")]
InputFile(std::io::Error),
/// Error parsing input.
#[error("Error parsing input: {0}")]
Parse(#[from] serde_json::Error),
#[display("Error parsing input: {_0}")]
Parse(serde_json::Error),
/// Error loading a package.
#[display("Error parsing package: {_0}")]
Package(PackageEncodingError),
#[display("Error validating HUGR: {_0}")]
/// Errors produced by the `validate` subcommand.
Validate(#[from] validate::ValError),
Validate(PackageValidationError),
}

/// Validate and visualise a HUGR file.
Expand All @@ -68,36 +74,10 @@ pub struct HugrArgs {
pub extensions: Vec<PathBuf>,
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
/// Package of module HUGRs and extensions.
/// The HUGRs are validated against the extensions.
pub struct Package {
/// Module HUGRs included in the package.
pub modules: Vec<Hugr>,
/// Extensions to validate against.
pub extensions: Vec<Extension>,
}

impl Package {
/// Create a new package.
pub fn new(modules: Vec<Hugr>, extensions: Vec<Extension>) -> Self {
Self {
modules,
extensions,
}
}
}

impl HugrArgs {
/// Read either a package or a single hugr from the input.
pub fn get_package(&mut self) -> Result<Package, CliError> {
let val: serde_json::Value = serde_json::from_reader(&mut self.input)?;
// read either a package or a single hugr
if let Ok(p) = serde_json::from_value::<Package>(val.clone()) {
Ok(p)
} else {
let hugr: Hugr = serde_json::from_value(val)?;
Ok(Package::new(vec![hugr], vec![]))
}
let pkg = Package::from_json_reader(&mut self.input)?;
Ok(pkg)
}
}
2 changes: 1 addition & 1 deletion hugr-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use clap_verbosity_flag::Level;
fn main() {
match CliArgs::parse() {
CliArgs::Validate(args) => run_validate(args),
CliArgs::GenExtensions(args) => args.run_dump(&hugr_core::std_extensions::STD_REG),
CliArgs::GenExtensions(args) => args.run_dump(&hugr::std_extensions::STD_REG),
CliArgs::Mermaid(mut args) => args.run_print().unwrap(),
CliArgs::External(_) => {
// TODO: Implement support for external commands.
Expand Down
2 changes: 1 addition & 1 deletion hugr-cli/src/mermaid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::io::Write;

use clap::Parser;
use clio::Output;
use hugr_core::HugrView;
use hugr::HugrView;

/// Dump the standard extensions.
#[derive(Parser, Debug)]
Expand Down
51 changes: 10 additions & 41 deletions hugr-cli/src/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

use clap::Parser;
use clap_verbosity_flag::Level;
use hugr_core::{extension::ExtensionRegistry, Extension, Hugr};
use thiserror::Error;
use hugr::package::PackageValidationError;
use hugr::{extension::ExtensionRegistry, Extension, Hugr};

use crate::{CliError, HugrArgs, Package};
use crate::{CliError, HugrArgs};

/// Validate and visualise a HUGR file.
#[derive(Parser, Debug)]
Expand All @@ -19,18 +19,6 @@ pub struct ValArgs {
pub hugr_args: HugrArgs,
}

/// Error type for the CLI.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ValError {
/// Error validating HUGR.
#[error("Error validating HUGR: {0}")]
Validate(#[from] hugr_core::hugr::ValidationError),
/// Error registering extension.
#[error("Error registering extension: {0}")]
ExtReg(#[from] hugr_core::extension::ExtensionRegistryError),
}

/// String to print when validation is successful.
pub const VALID_PRINT: &str = "HUGR valid!";

Expand All @@ -50,49 +38,30 @@ impl ValArgs {
}
}

impl Package {
/// Validate the package against an extension registry.
///
/// `reg` is updated with any new extensions.
///
/// Returns the validated modules.
pub fn validate(mut self, reg: &mut ExtensionRegistry) -> Result<Vec<Hugr>, ValError> {
// register packed extensions
for ext in self.extensions {
reg.register_updated(ext)?;
}

for hugr in self.modules.iter_mut() {
hugr.update_validate(reg)?;
}

Ok(self.modules)
}
}

impl HugrArgs {
/// Load the package and validate against an extension registry.
///
/// Returns the validated modules and the extension registry the modules
/// were validated against.
pub fn validate(&mut self) -> Result<(Vec<Hugr>, ExtensionRegistry), CliError> {
let package = self.get_package()?;
let mut package = self.get_package()?;

let mut reg: ExtensionRegistry = if self.no_std {
hugr_core::extension::PRELUDE_REGISTRY.to_owned()
hugr::extension::PRELUDE_REGISTRY.to_owned()
} else {
hugr_core::std_extensions::STD_REG.to_owned()
hugr::std_extensions::STD_REG.to_owned()
};

// register external extensions
for ext in &self.extensions {
let f = std::fs::File::open(ext)?;
let ext: Extension = serde_json::from_reader(f)?;
reg.register_updated(ext).map_err(ValError::ExtReg)?;
reg.register_updated(ext)
.map_err(PackageValidationError::Extension)?;
}

let modules = package.validate(&mut reg)?;
Ok((modules, reg))
package.validate(&mut reg)?;
Ok((package.modules, reg))
}

/// Test whether a `level` message should be output.
Expand Down
14 changes: 7 additions & 7 deletions hugr-cli/tests/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@

use assert_cmd::Command;
use assert_fs::{fixture::FileWriteStr, NamedTempFile};
use hugr_cli::{validate::VALID_PRINT, Package};
use hugr_core::builder::DFGBuilder;
use hugr_core::types::Type;
use hugr_core::{
use hugr::builder::DFGBuilder;
use hugr::types::Type;
use hugr::{
builder::{Container, Dataflow},
extension::prelude::{BOOL_T, QB_T},
std_extensions::arithmetic::float_types::FLOAT64_TYPE,
type_row,
types::Signature,
Hugr,
};
use hugr_cli::{validate::VALID_PRINT, Package};
use predicates::{prelude::*, str::contains};
use rstest::{fixture, rstest};

Expand Down Expand Up @@ -128,7 +128,7 @@ fn test_bad_json(mut val_cmd: Command) {
val_cmd
.assert()
.failure()
.stderr(contains("Error parsing input"));
.stderr(contains("Error parsing package"));
}

#[rstest]
Expand All @@ -139,7 +139,7 @@ fn test_bad_json_silent(mut val_cmd: Command) {
val_cmd
.assert()
.failure()
.stderr(contains("Error parsing input").not());
.stderr(contains("Error parsing package").not());
}

#[rstest]
Expand Down Expand Up @@ -188,7 +188,7 @@ fn test_float_extension(float_hugr_string: String, mut val_cmd: Command) {
#[fixture]
fn package_string(#[with(FLOAT64_TYPE)] test_hugr: Hugr) -> String {
let rdr = std::fs::File::open(FLOAT_EXT_FILE).unwrap();
let float_ext: hugr_core::Extension = serde_json::from_reader(rdr).unwrap();
let float_ext: hugr::Extension = serde_json::from_reader(rdr).unwrap();
let package = Package::new(vec![test_hugr], vec![float_ext]);
serde_json::to_string(&package).unwrap()
}
Expand Down
2 changes: 1 addition & 1 deletion hugr-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ serde = { workspace = true, features = ["derive", "rc"] }
serde_yaml = { workspace = true, optional = true }
typetag = { workspace = true }
smol_str = { workspace = true, features = ["serde"] }
derive_more = { workspace = true, features = ["display", "from"] }
derive_more = { workspace = true, features = ["display", "error", "from"] }
itertools = { workspace = true }
html-escape = { workspace = true }
bitvec = { workspace = true, features = ["serde"] }
Expand Down
42 changes: 39 additions & 3 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub use prelude::{PRELUDE, PRELUDE_REGISTRY};
pub mod declarative;

/// Extension Registries store extensions to be looked up e.g. during validation.
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub struct ExtensionRegistry(BTreeMap<ExtensionId, Extension>);

impl ExtensionRegistry {
Expand Down Expand Up @@ -92,6 +92,9 @@ impl ExtensionRegistry {
/// If extension IDs match, the extension with the higher version is kept.
/// If versions match, the original extension is kept.
/// Returns a reference to the registered extension if successful.
///
/// Avoids cloning the extension unless required. For a reference version see
/// [`ExtensionRegistry::register_updated_ref`].
pub fn register_updated(
&mut self,
extension: Extension,
Expand All @@ -107,6 +110,30 @@ impl ExtensionRegistry {
}
}

/// Registers a new extension to the registry, keeping most up to date if
/// extension exists.
///
/// If extension IDs match, the extension with the higher version is kept.
/// If versions match, the original extension is kept. Returns a reference
/// to the registered extension if successful.
///
/// Clones the extension if required. For no-cloning version see
/// [`ExtensionRegistry::register_updated`].
pub fn register_updated_ref(
&mut self,
extension: &Extension,
) -> Result<&Extension, ExtensionRegistryError> {
match self.0.entry(extension.name().clone()) {
btree_map::Entry::Occupied(mut prev) => {
if prev.get().version() < extension.version() {
*prev.get_mut() = extension.clone();
}
Ok(prev.into_mut())
Comment on lines +127 to +131
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doesn't appear to be tested

}
btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension.clone())),
}
}

/// Returns the number of extensions in the registry.
pub fn len(&self) -> usize {
self.0.len()
Expand Down Expand Up @@ -418,7 +445,7 @@ impl Extension {

impl PartialEq for Extension {
fn eq(&self, other: &Self) -> bool {
self.name == other.name
self.name == other.name && self.version == other.version
}
}

Expand Down Expand Up @@ -612,7 +639,11 @@ pub mod test {

#[test]
fn test_register_update() {
// Two registers that should remain the same.
// We use them to test both `register_updated` and `register_updated_ref`.
let mut reg = ExtensionRegistry::try_new([]).unwrap();
let mut reg_ref = ExtensionRegistry::try_new([]).unwrap();

let ext_1_id = ExtensionId::new("ext1").unwrap();
let ext_2_id = ExtensionId::new("ext2").unwrap();
let ext1 = Extension::new(ext_1_id.clone(), Version::new(1, 0, 0));
Expand All @@ -621,7 +652,8 @@ pub mod test {
let ext2 = Extension::new(ext_2_id, Version::new(1, 0, 0));

reg.register(ext1.clone()).unwrap();
assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 0, 0));
reg_ref.register(ext1.clone()).unwrap();
assert_eq!(&reg, &reg_ref);

// normal registration fails
assert_eq!(
Expand All @@ -634,12 +666,16 @@ pub mod test {
);

// register with update works
reg_ref.register_updated_ref(&ext1_1).unwrap();
reg.register_updated(ext1_1.clone()).unwrap();
assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0));
assert_eq!(&reg, &reg_ref);

// register with lower version does not change version
reg_ref.register_updated_ref(&ext1_2).unwrap();
reg.register_updated(ext1_2.clone()).unwrap();
assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0));
assert_eq!(&reg, &reg_ref);

reg.register(ext2.clone()).unwrap();
assert_eq!(reg.get("ext2").unwrap().version(), &Version::new(1, 0, 0));
Expand Down
1 change: 1 addition & 0 deletions hugr-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub mod hugr;
pub mod import;
pub mod macros;
pub mod ops;
pub mod package;
pub mod std_extensions;
pub mod types;
pub mod utils;
Expand Down
Loading
Loading