diff --git a/src/python/turicreate/toolkits/one_shot_object_detector/util/_augmentation.py b/src/python/turicreate/toolkits/one_shot_object_detector/util/_augmentation.py index 5367ca8019..4c13bfc71f 100644 --- a/src/python/turicreate/toolkits/one_shot_object_detector/util/_augmentation.py +++ b/src/python/turicreate/toolkits/one_shot_object_detector/util/_augmentation.py @@ -52,6 +52,15 @@ def preview_synthetic_training_data(data, backgrounds_tar = _tarfile.open(backgrounds_tar_path) backgrounds_tar.extractall() backgrounds = _tc.SArray("one_shot_backgrounds.sarray") + # We resize the background dimensions by half along each axis to reduce + # the disk footprint during augmentation, and also reduce the time + # taken to synthesize data. + backgrounds = backgrounds.apply(lambda im: _tc.image_analysis.resize( + im, + int(im.width/2), + int(im.height/2), + im.channels + )) # Option arguments to pass in to C++ Object Detector, if we use it: # {'mlmodel_path':'darknet.mlmodel', 'max_iterations' : 25} options_for_augmentation = { diff --git a/src/toolkits/object_detection/one_shot_object_detection/one_shot_object_detector.cpp b/src/toolkits/object_detection/one_shot_object_detection/one_shot_object_detector.cpp index f5c8a7e9f0..372ab97c7b 100644 --- a/src/toolkits/object_detection/one_shot_object_detection/one_shot_object_detector.cpp +++ b/src/toolkits/object_detection/one_shot_object_detection/one_shot_object_detector.cpp @@ -100,27 +100,21 @@ static std::map generate_column_index_map( return index_map; } -boost::gil::rgba8_image_t::view_t create_starter_image_view( - const flex_image &object_input) { +flex_image create_rgba_flex_image(const flex_image &object_input) { if (!(object_input.is_decoded())) { log_and_throw("Input object starter image is not decoded."); } - flex_image object = + flex_image rgba_flex_image = image_util::resize_image(object_input, object_input.m_width, object_input.m_height, 4, true) .to(); - if (!(object.is_decoded())) { + if (!(rgba_flex_image.is_decoded())) { log_and_throw("Resized object starter image is not decoded."); } - if (object.m_channels != 4) { + if (rgba_flex_image.m_channels != 4) { log_and_throw("Object image is not resized to be 4."); } - boost::gil::rgba8_image_t::view_t starter_image_view = interleaved_view( - object.m_width, object.m_height, - (boost::gil::rgba8_pixel_t *)(object.get_image_data()), - object.m_channels * object.m_width // row length in bytes - ); - return starter_image_view; + return rgba_flex_image; } std::pair @@ -149,12 +143,19 @@ create_synthetic_image_from_background_and_starter(const flex_image &starter, log_and_throw("Background image is not decoded into raw format."); } - boost::gil::rgba8_image_t::view_t starter_image_view = - create_starter_image_view(starter); + flex_image rgba_flex_image = create_rgba_flex_image(starter); + boost::gil::rgba8_image_t::const_view_t starter_image_view = + interleaved_view(rgba_flex_image.m_width, rgba_flex_image.m_height, + reinterpret_cast( + rgba_flex_image.get_image_data()), + rgba_flex_image.m_channels * + rgba_flex_image.m_width // row length in bytes + ); - boost::gil::rgb8_image_t::view_t background_view = interleaved_view( + boost::gil::rgb8_image_t::const_view_t background_view = interleaved_view( background.m_width, background.m_height, - (boost::gil::rgb8_pixel_t *)(background.get_image_data()), + reinterpret_cast( + background.get_image_data()), background.m_channels * background.m_width // row length in bytes ); flex_image synthetic_image = create_synthetic_image( diff --git a/src/toolkits/object_detection/one_shot_object_detection/one_shot_object_detector.hpp b/src/toolkits/object_detection/one_shot_object_detection/one_shot_object_detector.hpp index 179d37177b..a61a12710c 100644 --- a/src/toolkits/object_detection/one_shot_object_detection/one_shot_object_detector.hpp +++ b/src/toolkits/object_detection/one_shot_object_detection/one_shot_object_detector.hpp @@ -26,9 +26,8 @@ class EXPORT one_shot_object_detector: public ml_model_base { // Interface exposed via Unity server // TODO: augment -> train - gl_sframe augment(const gl_sframe &data, - const std::string& image_column_name, - const std::string& target_column_name, + gl_sframe augment(const gl_sframe &data, const std::string &image_column_name, + const std::string &target_column_name, const gl_sarray &backgrounds, std::map &options); diff --git a/src/toolkits/object_detection/one_shot_object_detection/util/superposition.cpp b/src/toolkits/object_detection/one_shot_object_detection/util/superposition.cpp index 53750aeb0c..b0dbf1f136 100644 --- a/src/toolkits/object_detection/one_shot_object_detection/util/superposition.cpp +++ b/src/toolkits/object_detection/one_shot_object_detection/util/superposition.cpp @@ -38,9 +38,10 @@ void superimpose_image(const boost::gil::rgb8_image_t::view_t &superimposed, }); } -flex_image create_synthetic_image(const boost::gil::rgba8_image_t::view_t &starter_image_view, - const boost::gil::rgb8_image_t::view_t &background_view, - ParameterSampler ¶meter_sampler) { +flex_image create_synthetic_image( + const boost::gil::rgba8_image_t::const_view_t &starter_image_view, + const boost::gil::rgb8_image_t::const_view_t &background_view, + ParameterSampler ¶meter_sampler) { boost::gil::rgba8_image_t background_rgba(boost::gil::rgba8_image_t::point_t(background_view.dimensions())); boost::gil::copy_and_convert_pixels( background_view, diff --git a/src/toolkits/object_detection/one_shot_object_detection/util/superposition.hpp b/src/toolkits/object_detection/one_shot_object_detection/util/superposition.hpp index 442856b86e..8a7f1da297 100644 --- a/src/toolkits/object_detection/one_shot_object_detection/util/superposition.hpp +++ b/src/toolkits/object_detection/one_shot_object_detection/util/superposition.hpp @@ -15,9 +15,10 @@ namespace turi { namespace one_shot_object_detection { namespace data_augmentation { -flex_image create_synthetic_image(const boost::gil::rgba8_image_t::view_t &starter_image_view, - const boost::gil::rgb8_image_t::view_t &background_view, - ParameterSampler ¶meter_sampler); +flex_image create_synthetic_image( + const boost::gil::rgba8_image_t::const_view_t &starter_image_view, + const boost::gil::rgb8_image_t::const_view_t &background_view, + ParameterSampler ¶meter_sampler); } // data_augmentation } // one_shot_object_detection