Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Ensure packages always have modules at the root #1589

Merged
merged 8 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 53 additions & 8 deletions hugr-cli/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use clap::Parser;
use clap_verbosity_flag::{InfoLevel, Verbosity};
use clio::Input;
use derive_more::{Display, Error, From};
use hugr::package::{PackageEncodingError, PackageValidationError};
use hugr::extension::ExtensionRegistry;
use hugr::package::PackageValidationError;
use hugr::Hugr;
use std::{ffi::OsString, path::PathBuf};

pub mod extensions;
Expand Down Expand Up @@ -40,11 +42,8 @@ pub enum CliError {
#[display("Error reading from path: {_0}")]
InputFile(std::io::Error),
/// Error parsing input.
#[display("Error parsing input: {_0}")]
Parse(serde_json::Error),
/// Error loading a package.
#[display("Error parsing package: {_0}")]
Package(PackageEncodingError),
Parse(serde_json::Error),
#[display("Error validating HUGR: {_0}")]
/// Errors produced by the `validate` subcommand.
Validate(PackageValidationError),
Expand Down Expand Up @@ -74,10 +73,56 @@ pub struct HugrArgs {
pub extensions: Vec<PathBuf>,
}

/// A simple enum containing either a package or a single hugr.
///
/// This is required since `Package`s can only contain module-rooted hugrs.
#[derive(Debug, Clone, PartialEq)]
pub enum PackageOrHugr {
/// A package with module-rooted HUGRs and some required extensions.
Package(Package),
/// An arbitrary HUGR.
Hugr(Hugr),
}

impl PackageOrHugr {
/// Returns the slice of hugrs in the package.
pub fn hugrs(&self) -> &[Hugr] {
match self {
PackageOrHugr::Package(pkg) => &pkg.modules,
PackageOrHugr::Hugr(hugr) => std::slice::from_ref(hugr),
}
}

/// Returns the list of hugrs in the package.
pub fn into_hugrs(self) -> Vec<Hugr> {
match self {
PackageOrHugr::Package(pkg) => pkg.modules,
PackageOrHugr::Hugr(hugr) => vec![hugr],
}
}

/// Validates the package or hugr.
///
/// Updates the extension registry with any new extensions defined in the package.
pub fn update_validate(
&mut self,
reg: &mut ExtensionRegistry,
) -> Result<(), PackageValidationError> {
match self {
PackageOrHugr::Package(pkg) => pkg.validate(reg),
PackageOrHugr::Hugr(hugr) => hugr.update_validate(reg).map_err(Into::into),
}
}
}

impl HugrArgs {
/// Read either a package or a single hugr from the input.
pub fn get_package(&mut self) -> Result<Package, CliError> {
let pkg = Package::from_json_reader(&mut self.input)?;
Ok(pkg)
pub fn get_package_or_hugr(&mut self) -> Result<PackageOrHugr, CliError> {
let val: serde_json::Value = serde_json::from_reader(&mut self.input)?;
if let Ok(hugr) = serde_json::from_value::<Hugr>(val.clone()) {
return Ok(PackageOrHugr::Hugr(hugr));
}
let pkg = serde_json::from_value::<Package>(val.clone())?;
Ok(PackageOrHugr::Package(pkg))
}
}
2 changes: 1 addition & 1 deletion hugr-cli/src/mermaid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl MermaidArgs {
let hugrs = if self.validate {
self.hugr_args.validate()?.0
} else {
self.hugr_args.get_package()?.modules
self.hugr_args.get_package_or_hugr()?.into_hugrs()
};

for hugr in hugrs {
Expand Down
6 changes: 3 additions & 3 deletions hugr-cli/src/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ impl HugrArgs {
/// Returns the validated modules and the extension registry the modules
/// were validated against.
pub fn validate(&mut self) -> Result<(Vec<Hugr>, ExtensionRegistry), CliError> {
let mut package = self.get_package()?;
let mut package = self.get_package_or_hugr()?;

let mut reg: ExtensionRegistry = if self.no_std {
hugr::extension::PRELUDE_REGISTRY.to_owned()
Expand All @@ -60,8 +60,8 @@ impl HugrArgs {
.map_err(PackageValidationError::Extension)?;
}

package.validate(&mut reg)?;
Ok((package.modules, reg))
package.update_validate(&mut reg)?;
Ok((package.into_hugrs(), reg))
}

/// Test whether a `level` message should be output.
Expand Down
40 changes: 27 additions & 13 deletions hugr-cli/tests/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

use assert_cmd::Command;
use assert_fs::{fixture::FileWriteStr, NamedTempFile};
use hugr::builder::DFGBuilder;
use hugr::builder::{DFGBuilder, DataflowSubContainer, ModuleBuilder};
use hugr::types::Type;
use hugr::{
builder::{Container, Dataflow},
Expand All @@ -31,6 +31,29 @@ fn val_cmd(mut cmd: Command) -> Command {
cmd
}

// path to the fully serialized float extension
const FLOAT_EXT_FILE: &str = concat!(
env!("CARGO_MANIFEST_DIR"),
"/../specification/std_extensions/arithmetic/float/types.json"
);

/// A test package, containing a module-rooted HUGR.
#[fixture]
fn test_package(#[default(BOOL_T)] id_type: Type) -> Package {
let mut module = ModuleBuilder::new();
let df = module
.define_function("test", Signature::new_endo(id_type))
.unwrap();
let [i] = df.input_wires_arr();
df.finish_with_outputs([i]).unwrap();
let hugr = module.hugr().clone(); // unvalidated

let rdr = std::fs::File::open(FLOAT_EXT_FILE).unwrap();
let float_ext: hugr::Extension = serde_json::from_reader(rdr).unwrap();
Package::new(vec![hugr], vec![float_ext]).unwrap()
}

/// A DFG-rooted HUGR.
#[fixture]
fn test_hugr(#[default(BOOL_T)] id_type: Type) -> Hugr {
let mut df = DFGBuilder::new(Signature::new_endo(id_type)).unwrap();
Expand Down Expand Up @@ -169,12 +192,6 @@ fn test_no_std_fail(float_hugr_string: String, mut val_cmd: Command) {
.stderr(contains(" Extension 'arithmetic.float.types' not found"));
}

// path to the fully serialized float extension
const FLOAT_EXT_FILE: &str = concat!(
env!("CARGO_MANIFEST_DIR"),
"/../specification/std_extensions/arithmetic/float/types.json"
);

#[rstest]
fn test_float_extension(float_hugr_string: String, mut val_cmd: Command) {
val_cmd.write_stdin(float_hugr_string);
Expand All @@ -186,15 +203,12 @@ fn test_float_extension(float_hugr_string: String, mut val_cmd: Command) {
val_cmd.assert().success().stderr(contains(VALID_PRINT));
}
#[fixture]
fn package_string(#[with(FLOAT64_TYPE)] test_hugr: Hugr) -> String {
let rdr = std::fs::File::open(FLOAT_EXT_FILE).unwrap();
let float_ext: hugr::Extension = serde_json::from_reader(rdr).unwrap();
let package = Package::new(vec![test_hugr], vec![float_ext]);
serde_json::to_string(&package).unwrap()
fn package_string(#[with(FLOAT64_TYPE)] test_package: Package) -> String {
serde_json::to_string(&test_package).unwrap()
}

#[rstest]
fn test_package(package_string: String, mut val_cmd: Command) {
fn test_package_validation(package_string: String, mut val_cmd: Command) {
// package with float extension and hugr that uses floats can validate
val_cmd.write_stdin(package_string);
val_cmd.arg("-");
Expand Down
Loading
Loading