Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

zcash_client_sqlite: Ensure that all shielded change outputs are correctly flagged. #1585

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions zcash_client_sqlite/src/testing/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,26 +147,20 @@ unsafe fn run_sqlite3<S: AsRef<OsStr>>(db_path: S, command: &str) {
eprintln!("------");
}

#[derive(Default)]
pub(crate) struct TestDbFactory {
target_migrations: Option<Vec<Uuid>>,
}

impl TestDbFactory {
#[allow(dead_code)]
pub(crate) fn new(target_migrations: Vec<Uuid>) -> Self {
Self {
target_migrations: Some(target_migrations),
}
}
}

impl Default for TestDbFactory {
fn default() -> Self {
Self {
target_migrations: Default::default(),
}
}
}

impl DataStoreFactory for TestDbFactory {
type Error = ();
type AccountId = AccountId;
Expand All @@ -178,7 +172,7 @@ impl DataStoreFactory for TestDbFactory {
let data_file = NamedTempFile::new().unwrap();
let mut db_data = WalletDb::for_path(data_file.path(), network).unwrap();
if let Some(migrations) = &self.target_migrations {
init_wallet_db_internal(&mut db_data, None, &migrations, true).unwrap();
init_wallet_db_internal(&mut db_data, None, migrations, true).unwrap();
} else {
init_wallet_db(&mut db_data, None).unwrap();
}
Expand Down
4 changes: 4 additions & 0 deletions zcash_client_sqlite/src/wallet/init/migrations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod add_utxo_account;
mod addresses_table;
mod ensure_orchard_ua_receiver;
mod ephemeral_addresses;
mod fix_bad_change_flagging;
mod fix_broken_commitment_trees;
mod full_account_ids;
mod initial_setup;
Expand Down Expand Up @@ -80,6 +81,8 @@ pub(super) fn all_migrations<P: consensus::Parameters + 'static>(
// support_legacy_sqlite
// |
// fix_broken_commitment_trees
// |
// fix_bad_change_flagging
vec![
Box::new(initial_setup::Migration {}),
Box::new(utxos_table::Migration {}),
Expand Down Expand Up @@ -141,6 +144,7 @@ pub(super) fn all_migrations<P: consensus::Parameters + 'static>(
Box::new(fix_broken_commitment_trees::Migration {
params: params.clone(),
}),
Box::new(fix_bad_change_flagging::Migration),
]
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
//! Sets the `is_change` flag on output notes received by an internal key when input value was
//! provided from the account corresponding to that key.
use std::collections::HashSet;

use rusqlite::named_params;
use schemerz_rusqlite::RusqliteMigration;
use uuid::Uuid;
use zip32::Scope;

use crate::{
wallet::{
init::{migrations::fix_broken_commitment_trees, WalletMigrationError},
scope_code,
},
SAPLING_TABLES_PREFIX,
};

#[cfg(feature = "orchard")]
use crate::ORCHARD_TABLES_PREFIX;

pub(super) const MIGRATION_ID: Uuid = Uuid::from_u128(0x6d36656d_533b_4b65_ae91_dcb95c4ad289);

const DEPENDENCIES: &[Uuid] = &[fix_broken_commitment_trees::MIGRATION_ID];

pub(super) struct Migration;

impl schemerz::Migration<Uuid> for Migration {
fn id(&self) -> Uuid {
MIGRATION_ID
}

fn dependencies(&self) -> HashSet<Uuid> {
DEPENDENCIES.iter().copied().collect()
}

fn description(&self) -> &'static str {
"Sets the `is_change` flag on output notes received by an internal key when input value was provided from the account corresponding to that key."
}
}

impl RusqliteMigration for Migration {
type Error = WalletMigrationError;

fn up(&self, transaction: &rusqlite::Transaction) -> Result<(), WalletMigrationError> {
let fix_change_flag = |table_prefix| {
transaction.execute(
&format!(
"UPDATE {table_prefix}_received_notes
SET is_change = 1
FROM sent_notes sn
WHERE sn.tx = {table_prefix}_received_notes.tx
AND sn.from_account_id = {table_prefix}_received_notes.account_id
nuttycom marked this conversation as resolved.
Show resolved Hide resolved
AND {table_prefix}_received_notes.recipient_key_scope = :internal_scope"
nuttycom marked this conversation as resolved.
Show resolved Hide resolved
),
named_params! {":internal_scope": scope_code(Scope::Internal)},
)
};

fix_change_flag(SAPLING_TABLES_PREFIX)?;
#[cfg(feature = "orchard")]
fix_change_flag(ORCHARD_TABLES_PREFIX)?;

Ok(())
}

fn down(&self, _: &rusqlite::Transaction) -> Result<(), WalletMigrationError> {
Err(WalletMigrationError::CannotRevert(MIGRATION_ID))
}
}

#[cfg(test)]
mod tests {
use crate::wallet::init::migrations::tests::test_migrate;

#[cfg(feature = "transparent-inputs")]
use {
crate::{
testing::{db::TestDbFactory, BlockCache},
wallet::init::init_wallet_db,
},
zcash_client_backend::{
data_api::{
testing::{
pool::ShieldedPoolTester, sapling::SaplingPoolTester, AddressType, TestBuilder,
},
wallet::input_selection::GreedyInputSelector,
Account as _, WalletRead as _, WalletWrite as _,
},
fees::{standard, DustOutputPolicy},
wallet::WalletTransparentOutput,
},
zcash_primitives::{
block::BlockHash,
transaction::{
components::{OutPoint, TxOut},
fees::StandardFeeRule,
},
},
zcash_protocol::value::Zatoshis,
};

#[test]
fn migrate() {
test_migrate(&[super::MIGRATION_ID]);
}

#[cfg(feature = "transparent-inputs")]
fn shield_transparent<T: ShieldedPoolTester>() {
let ds_factory = TestDbFactory::new(super::DEPENDENCIES.to_vec());
let cache = BlockCache::new();
let mut st = TestBuilder::new()
.with_data_store_factory(ds_factory)
.with_block_cache(cache)
.with_account_from_sapling_activation(BlockHash([0; 32]))
.build();

let account = st.test_account().cloned().unwrap();
let dfvk = T::test_account_fvk(&st);

let uaddr = st
.wallet()
.get_current_address(account.id())
.unwrap()
.unwrap();
let taddr = uaddr.transparent().unwrap();

// Ensure that the wallet has at least one block
let (h, _, _) = st.generate_next_block(
&dfvk,
AddressType::Internal,
Zatoshis::const_from_u64(50000),
);
st.scan_cached_blocks(h, 1);

let utxo = WalletTransparentOutput::from_parts(
OutPoint::fake(),
TxOut {
value: Zatoshis::const_from_u64(100000),
script_pubkey: taddr.script(),
},
Some(h),
)
.unwrap();

let res0 = st.wallet_mut().put_received_transparent_utxo(&utxo);
assert_matches!(res0, Ok(_));

let fee_rule = StandardFeeRule::Zip317;

let input_selector = GreedyInputSelector::new(
standard::SingleOutputChangeStrategy::new(fee_rule, None, T::SHIELDED_PROTOCOL),
DustOutputPolicy::default(),
);

let txids = st
.shield_transparent_funds(
&input_selector,
Zatoshis::from_u64(10000).unwrap(),
account.usk(),
&[*taddr],
1,
)
.unwrap();
assert_eq!(txids.len(), 1);

let tx = st.get_tx_from_history(*txids.first()).unwrap().unwrap();
assert_eq!(tx.spent_note_count(), 1);
assert!(tx.has_change());
assert_eq!(tx.received_note_count(), 0);
assert_eq!(tx.sent_note_count(), 0);
assert!(tx.is_shielding());

// Complete the migration
init_wallet_db(st.wallet_mut().db_mut(), None).unwrap();

// Ensure that the transaction metadata is still correct after the update produced by scanning.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What scanning update? AFAICT this test doesn't do one; it just checks that the unscanned shielding transaction's metadata is not altered by the migration. It might indeed be useful to have two tests for this migration: one that ensures correct data isn't altered, and one that ensures incorrect data gets fixed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy/paste error - this comment was correct in the test for the change from IS_NULL to MAX, and the test is the same.

let tx = st.get_tx_from_history(*txids.first()).unwrap().unwrap();
assert_eq!(tx.spent_note_count(), 1);
assert!(tx.has_change());
assert_eq!(tx.received_note_count(), 0);
assert_eq!(tx.sent_note_count(), 0);
assert!(tx.is_shielding());
}

#[test]
#[cfg(feature = "transparent-inputs")]
fn sapling_shield_transparent() {
shield_transparent::<SaplingPoolTester>();
}

#[test]
#[cfg(all(feature = "orchard", feature = "transparent-inputs"))]
fn orchard_shield_transparent() {
use zcash_client_backend::data_api::testing::orchard::OrchardPoolTester;

shield_transparent::<OrchardPoolTester>();
}
}