JP-3677: Add maximum_shower_amplitude parameter to jump step #306

Open · wants to merge 5 commits into main
2 changes: 2 additions & 0 deletions changes/307.jump.rst
@@ -0,0 +1,2 @@
Add maximum_shower_amplitude parameter to the MIRI cosmic ray showers routine
to fix accidental flagging of bright science pixels.
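For context, here is a minimal self-contained sketch of the guard this parameter enables (illustrative only, not the PR's code: the DQ bit values, array shapes, and the mean_group_diff helper are invented here). The idea is to estimate per-pixel rates before and after shower flagging, then revert the flags wherever flagging itself changed the rate by more than a plausible shower amplitude:

    import warnings

    import numpy as np

    # Hypothetical DQ bit values; the real code takes these from the dqflags dict.
    DNU, SAT, JUMP = 1, 2, 4
    INVALID = DNU | SAT | JUMP
    MAX_SHOWER_AMPLITUDE = 6.0  # DN/group


    def mean_group_diff(data, dq):
        # Mean group-to-group difference per pixel, ignoring flagged groups.
        tmp = data.copy()
        tmp[(dq & INVALID) != 0] = np.nan
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            return np.nanmean(np.diff(tmp, axis=0), axis=0)


    rng = np.random.default_rng(1)
    ngroups, nrows, ncols = 10, 8, 8
    # Ramps accumulating ~5 DN/group, plus one bright pixel that steps up at group 3.
    data = np.cumsum(rng.normal(5.0, 1.0, (ngroups, nrows, ncols)), axis=0).astype(np.float32)
    data[3:, 2, 2] += 100.0

    gdq_before = np.zeros(data.shape, dtype=np.uint32)
    gdq_after = gdq_before.copy()
    gdq_after[3:, 2, 2] |= JUMP  # suppose the shower routine flagged the bright pixel

    rate1 = mean_group_diff(data, gdq_before)  # rates before shower flagging
    rate2 = mean_group_diff(data, gdq_after)   # rates after shower flagging
    change = np.abs(rate1 - rate2)

    # Revert to the original flags wherever flagging changed the implied rate by
    # more than a plausible shower amplitude (or left no usable groups at all).
    revert = ~np.isfinite(change) | (change > MAX_SHOWER_AMPLITUDE)
    gdq_after[:, revert] = gdq_before[:, revert]
    assert gdq_after[5, 2, 2] == 0  # the bright science pixel is no longer flagged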
168 changes: 88 additions & 80 deletions src/stcal/jump/jump.py
@@ -60,6 +60,7 @@
min_diffs_single_pass=10,
mask_persist_grps_next_int=True,
persist_grps_flagged=25,
max_shower_amplitude=12
):
"""
This is the high-level controlling routine for the jump detection process.
@@ -220,6 +221,8 @@
then all differences are processed at once.
min_diffs_single_pass : int
The minimum number of groups to switch to flagging all outliers in a single pass.
max_shower_amplitude : float
The maximum possible amplitude for flagged MIRI showers, in DN/group.

Returns
-------
@@ -298,46 +301,7 @@
dqflags['DO_NOT_USE']
gdq[gdq == np.bitwise_or(dqflags['SATURATED'], dqflags['JUMP_DET'])] = \
dqflags['SATURATED']
# This is the flag that controls the flagging of snowballs.
if expand_large_events:
gdq, total_snowballs = flag_large_events(
gdq,
jump_flag,
sat_flag,
min_sat_area=min_sat_area,
min_jump_area=min_jump_area,
expand_factor=expand_factor,
sat_required_snowball=sat_required_snowball,
min_sat_radius_extend=min_sat_radius_extend,
edge_size=edge_size,
sat_expand=sat_expand,
max_extended_radius=max_extended_radius,
mask_persist_grps_next_int=mask_persist_grps_next_int,
persist_grps_flagged=persist_grps_flagged,
)
log.info("Total snowballs = %i", total_snowballs)
number_extended_events = total_snowballs
if find_showers:
gdq, num_showers = find_faint_extended(
data,
gdq,
pdq,
readnoise_2d,
frames_per_group,
minimum_sigclip_groups,
dqflags,
snr_threshold=extend_snr_threshold,
min_shower_area=extend_min_area,
inner=extend_inner_radius,
outer=extend_outer_radius,
sat_flag=sat_flag,
jump_flag=jump_flag,
ellipse_expand=extend_ellipse_expand_ratio,
num_grps_masked=grps_masked_after_shower,
max_extended_radius=max_extended_radius,
)
log.info("Total showers= %i", num_showers)
number_extended_events = num_showers

else:
yinc = int(n_rows // n_slices)
slices = []
@@ -463,46 +427,50 @@
gdq[gdq == np.bitwise_or(dqflags['SATURATED'], dqflags['JUMP_DET'])] = \
dqflags['SATURATED']

# This is the flag that controls the flagging of snowballs.
if expand_large_events:
gdq, total_snowballs = flag_large_events(
gdq,
jump_flag,
sat_flag,
min_sat_area=min_sat_area,
min_jump_area=min_jump_area,
expand_factor=expand_factor,
sat_required_snowball=sat_required_snowball,
min_sat_radius_extend=min_sat_radius_extend,
edge_size=edge_size,
sat_expand=sat_expand,
max_extended_radius=max_extended_radius,
mask_persist_grps_next_int=mask_persist_grps_next_int,
persist_grps_flagged=persist_grps_flagged,
)
log.info("Total snowballs = %i", total_snowballs)
number_extended_events = total_snowballs
if find_showers:
gdq, num_showers = find_faint_extended(
data,
gdq,
pdq,
readnoise_2d,
frames_per_group,
minimum_sigclip_groups,
dqflags,
snr_threshold=extend_snr_threshold,
min_shower_area=extend_min_area,
inner=extend_inner_radius,
outer=extend_outer_radius,
sat_flag=sat_flag,
jump_flag=jump_flag,
ellipse_expand=extend_ellipse_expand_ratio,
num_grps_masked=grps_masked_after_shower,
max_extended_radius=max_extended_radius,
)
log.info("Total showers= %i", num_showers)
number_extended_events = num_showers
# Look for snowballs in near-IR data
if expand_large_events:
gdq, total_snowballs = flag_large_events(
gdq,
jump_flag,
sat_flag,
min_sat_area=min_sat_area,
min_jump_area=min_jump_area,
expand_factor=expand_factor,
sat_required_snowball=sat_required_snowball,
min_sat_radius_extend=min_sat_radius_extend,
edge_size=edge_size,
sat_expand=sat_expand,
max_extended_radius=max_extended_radius,
mask_persist_grps_next_int=mask_persist_grps_next_int,
persist_grps_flagged=persist_grps_flagged,
)
log.info("Total snowballs = %i", total_snowballs)
number_extended_events = total_snowballs

# Look for showers in mid-IR data
if find_showers:
gdq, num_showers = find_faint_extended(
data,
gdq,
pdq,
readnoise_2d,
frames_per_group,
minimum_sigclip_groups,
dqflags,
snr_threshold=extend_snr_threshold,
min_shower_area=extend_min_area,
inner=extend_inner_radius,
outer=extend_outer_radius,
sat_flag=sat_flag,
jump_flag=jump_flag,
ellipse_expand=extend_ellipse_expand_ratio,
num_grps_masked=grps_masked_after_shower,
max_extended_radius=max_extended_radius,
max_shower_amplitude=max_shower_amplitude
)
log.info("Total showers= %i", num_showers)
number_extended_events = num_showers

elapsed = time.time() - start
log.info("Total elapsed time = %g sec", elapsed)

Expand Down Expand Up @@ -878,6 +846,7 @@
)


# MIRI cosmic ray showers code
def find_faint_extended(
indata,
ingdq,
@@ -897,6 +866,7 @@
num_grps_masked=25,
max_extended_radius=200,
min_diffs_for_shower=10,
max_shower_amplitude=6,
):
"""
Parameters
@@ -931,6 +901,8 @@
The upper limit for the extension of saturation and jump
minimum_sigclip_groups : int
The minimum number of groups to use sigma clipping.
max_shower_amplitude : float
The maximum amplitude of shower artifacts to correct, in DN/group.


Returns
@@ -948,6 +920,7 @@
nints = data.shape[0]
ngrps = data.shape[1]
num_grps_donotuse = 0

for integ in range(nints):
for grp in range(ngrps):
if np.all(np.bitwise_and(gdq[integ, grp, :, :], donotuse_flag)):
@@ -1111,6 +1084,41 @@
num_grps_masked=num_grps_masked,
max_extended_radius=max_extended_radius
)

# Ensure that flagging showers didn't change final fluxes by more than the allowed amount
for intg in range(nints):
Review comment (Collaborator):
Is this loop needed? Is it possible to operate on the full data set, taking advantage of fast numpy looping, instead of slow explicit looping?

Reply (@drlaw1558, Contributor Author, Oct 17, 2024):
Here the worry was that calling .copy() on the entire array could be extremely memory intensive for cases where there were a large number of integrations. Looping over the integration is less efficient, but it looks like it still only takes 3 seconds even for a TSO case broken into 18 ints, 30 groups, and 1024x1032 pixels.
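For scale, a rough back-of-envelope of the trade-off described in that reply (figures are illustrative, assuming float32 data and the quoted TSO case; not from the PR):

    import numpy as np

    # TSO case quoted above: 18 ints, 30 groups, 1024 x 1032 pixels, float32
    nints, ngroups, nrows, ncols = 18, 30, 1024, 1032
    itemsize = np.dtype(np.float32).itemsize

    full_copy = nints * ngroups * nrows * ncols * itemsize  # copy the whole 4-D cube at once
    per_int_copy = ngroups * nrows * ncols * itemsize       # copy one integration per loop pass
    print(f"full copy ~{full_copy / 1e9:.1f} GB; per-integration ~{per_int_copy / 1e9:.2f} GB")
    # full copy ~2.3 GB; per-integration ~0.13 GB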

# Consider DO_NOT_USE, SATURATION, and JUMP_DET flags
invalid_flags = donotuse_flag | sat_flag | jump_flag

# Approximate pre-shower rates
tempdata = indata[intg, :, :, :].copy()
# Ignore any groups flagged in the original gdq array
tempdata[ingdq[intg, :, :, :] & invalid_flags != 0] = np.nan
# Compute group differences
diff = np.diff(tempdata, axis=0)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=RuntimeWarning, message="All-NaN")
warnings.filterwarnings("ignore", category=RuntimeWarning, message="Mean of empty slice")
image1 = np.nanmean(diff, axis=0)

# Approximate post-shower rates
tempdata = indata[intg, :, :, :].copy()
# Ignore any groups flagged in the shower gdq array
tempdata[gdq[intg, :, :, :] & invalid_flags != 0] = np.nan
# Compute group differences
diff = np.diff(tempdata, axis=0)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=RuntimeWarning, message="All-NaN")
warnings.filterwarnings("ignore", category=RuntimeWarning, message="Mean of empty slice")
image2 = np.nanmean(diff, axis=0)

# Revert the group flags to the pre-shower flags for any pixels whose rates
# became NaN or changed by more than the amount reasonable for a real CR shower
# Note that max_shower_amplitude should now be in DN/group not DN/s
diff = np.abs(image1 - image2)
indx = np.where((np.isfinite(diff) == False) | (diff > max_shower_amplitude))
gdq[intg, :, indx[0], indx[1]] = ingdq[intg, :, indx[0], indx[1]]

return gdq, total_showers

def find_first_good_group(int_gdq, do_not_use):
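One numpy subtlety in the revert step above: in gdq[intg, :, indx[0], indx[1]], the integer index and the paired pixel-index arrays are separated by the slice, so numpy broadcasts them together and moves the resulting pixel axis to the front, giving a selection of shape (npixels, ngroups); used as an assignment target, the same expression writes those locations in place. A tiny demo with invented shapes and values:

    import numpy as np

    # Shapes invented for illustration: (ints, groups, rows, cols)
    gdq = np.arange(2 * 3 * 4 * 4, dtype=np.uint32).reshape(2, 3, 4, 4)
    bad = np.zeros((4, 4), dtype=bool)
    bad[1, 2] = bad[3, 0] = True
    indx = np.where(bad)                 # (array([1, 3]), array([2, 0]))

    sub = gdq[0, :, indx[0], indx[1]]
    print(sub.shape)                     # (2, 3): one row per flagged pixel, all groups
    # As an assignment target, the same expression writes in place:
    gdq[0, :, indx[0], indx[1]] = 0
    assert gdq[0, 1, 1, 2] == 0 and gdq[0, 2, 3, 0] == 0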
31 changes: 1 addition & 30 deletions tests/test_jump.py
@@ -383,6 +383,7 @@ def test_find_faint_extended():
jump_flag=4,
ellipse_expand=1.,
num_grps_masked=1,
max_shower_amplitude=10
)
# Check that all the expected samples in group 2 are flagged as jump and
# that they are not flagged outside
@@ -405,36 +406,6 @@
# Check that the flags are not applied in the 3rd group after the event
assert np.all(gdq[0, 4, 12:22, 14:23] == 0)

def test_find_faint_extended():
nint, ngrps, ncols, nrows = 1, 66, 5, 5
data = np.zeros(shape=(nint, ngrps, nrows, ncols), dtype=np.float32)
gdq = np.zeros_like(data, dtype=np.uint32)
pdq = np.zeros(shape=(nrows, ncols), dtype=np.uint32)
pdq[0, 0] = 1
pdq[1, 1] = 2147483648
# pdq = np.zeros(shape=(data.shape[2], data.shape[3]), dtype=np.uint8)
gain = 4
readnoise = np.ones(shape=(nrows, ncols), dtype=np.float32) * 6.0 * gain
rng = np.random.default_rng(12345)
data[0, 1:, 14:20, 15:20] = 6 * gain * 6.0 * np.sqrt(2)
data = data + rng.normal(size=(nint, ngrps, nrows, ncols)) * readnoise
gdq, num_showers = find_faint_extended(
data,
gdq,
pdq,
readnoise * np.sqrt(2),
1,
100,
snr_threshold=3,
min_shower_area=10,
inner=1,
outer=2.6,
sat_flag=2,
jump_flag=4,
ellipse_expand=1.1,
num_grps_masked=0,
)


# No shower is found because the event is identical in all ints
def test_find_faint_extended_sigclip():