diff --git a/app/src/main/java/com/techcourse/DispatcherServlet.java b/app/src/main/java/com/techcourse/DispatcherServlet.java index 277d8eed9a..2036cfb3e8 100644 --- a/app/src/main/java/com/techcourse/DispatcherServlet.java +++ b/app/src/main/java/com/techcourse/DispatcherServlet.java @@ -4,42 +4,92 @@ import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; +import java.util.ArrayList; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import webmvc.org.springframework.web.servlet.ModelAndView; +import webmvc.org.springframework.web.servlet.mvc.tobe.HandlerAdapter; +import webmvc.org.springframework.web.servlet.mvc.tobe.HandlerMapping; import webmvc.org.springframework.web.servlet.view.JspView; public class DispatcherServlet extends HttpServlet { private static final long serialVersionUID = 1L; private static final Logger log = LoggerFactory.getLogger(DispatcherServlet.class); + private static final String BASE_PACKAGE_PATH = "com.techcourse"; - private ManualHandlerMapping manualHandlerMapping; + private final List handlerMappings; + private final List handlerAdapters; public DispatcherServlet() { + handlerMappings = new ArrayList<>(); + handlerAdapters = new ArrayList<>(); } @Override public void init() { - manualHandlerMapping = new ManualHandlerMapping(); - manualHandlerMapping.initialize(); + initHandlerMappings(); + initHandlerAdapters(); + } + + private void initHandlerMappings() { + final List handlerMappingInstances = HandlerMappingFactory.getHandlerMappings(BASE_PACKAGE_PATH) + .stream() + .peek(HandlerMapping::initialize) + .collect(Collectors.toList()); + handlerMappings.addAll(handlerMappingInstances); + } + + private void initHandlerAdapters() { + final List handlerAdapterInstances = HandlerAdapterFactory.getHandlerAdapters(); + handlerAdapters.addAll(handlerAdapterInstances); } - @Override - protected void service(final HttpServletRequest request, final HttpServletResponse response) throws ServletException { - final String requestURI = request.getRequestURI(); - log.debug("Method : {}, Request URI : {}", request.getMethod(), requestURI); + @Override + protected void service(final HttpServletRequest request, final HttpServletResponse response) + throws ServletException { + log.debug("Method : {}, Request URI : {}", request.getMethod(), request.getRequestURI()); try { - final var controller = manualHandlerMapping.getHandler(requestURI); - final var viewName = controller.execute(request, response); - move(viewName, request, response); - } catch (Throwable e) { + process(request, response); + } catch (Exception e) { log.error("Exception : {}", e.getMessage(), e); throw new ServletException(e.getMessage()); } } - private void move(final String viewName, final HttpServletRequest request, final HttpServletResponse response) throws Exception { + private void process(final HttpServletRequest request, final HttpServletResponse response) + throws Exception { + final Object handler = getHandler(request); + final HandlerAdapter handlerAdapter = getHandlerAdapter(handler); + final ModelAndView modelAndView = handlerAdapter.handle(request, response, handler); + move(modelAndView, request, response); + } + + private Object getHandler(final HttpServletRequest request) { + return handlerMappings.stream() + .filter(mapping -> mapping.getHandler(request) != null) + .findFirst() + .orElseThrow(() -> new NoSuchElementException("해당하는 HandlerMapping이 없습니다.")) + .getHandler(request); + } + + private HandlerAdapter getHandlerAdapter(final Object handler) { + return handlerAdapters.stream() + .filter(adapter -> adapter.supports(handler)) + .findFirst() + .orElseThrow(() -> new NoSuchElementException("해당하는 HandlerAdapter가 없습니다.")); + } + + private void move( + final ModelAndView modelAndView, + final HttpServletRequest request, + final HttpServletResponse response + ) throws Exception { + final String viewName = modelAndView.getViewName(); if (viewName.startsWith(JspView.REDIRECT_PREFIX)) { response.sendRedirect(viewName.substring(JspView.REDIRECT_PREFIX.length())); return; diff --git a/app/src/main/java/com/techcourse/HandlerAdapterFactory.java b/app/src/main/java/com/techcourse/HandlerAdapterFactory.java new file mode 100644 index 0000000000..542ef51a32 --- /dev/null +++ b/app/src/main/java/com/techcourse/HandlerAdapterFactory.java @@ -0,0 +1,16 @@ +package com.techcourse; + +import java.util.List; +import webmvc.org.springframework.web.servlet.mvc.tobe.AnnotationHandlerAdapter; +import webmvc.org.springframework.web.servlet.mvc.tobe.HandlerAdapter; +import webmvc.org.springframework.web.servlet.mvc.tobe.ManualHandlerAdapter; + +public class HandlerAdapterFactory { + + public static List getHandlerAdapters() { + return List.of( + new AnnotationHandlerAdapter(), + new ManualHandlerAdapter() + ); + } +} diff --git a/app/src/main/java/com/techcourse/HandlerMappingFactory.java b/app/src/main/java/com/techcourse/HandlerMappingFactory.java new file mode 100644 index 0000000000..2d2b2219e5 --- /dev/null +++ b/app/src/main/java/com/techcourse/HandlerMappingFactory.java @@ -0,0 +1,15 @@ +package com.techcourse; + +import java.util.List; +import webmvc.org.springframework.web.servlet.mvc.tobe.AnnotationHandlerMapping; +import webmvc.org.springframework.web.servlet.mvc.tobe.HandlerMapping; + +public class HandlerMappingFactory { + + public static List getHandlerMappings(final Object... basePackagePath) { + return List.of( + new AnnotationHandlerMapping(basePackagePath), + new ManualHandlerMapping() + ); + } +} diff --git a/app/src/main/java/com/techcourse/ManualHandlerMapping.java b/app/src/main/java/com/techcourse/ManualHandlerMapping.java index a54863caf8..2da0caf5ce 100644 --- a/app/src/main/java/com/techcourse/ManualHandlerMapping.java +++ b/app/src/main/java/com/techcourse/ManualHandlerMapping.java @@ -1,20 +1,25 @@ package com.techcourse; -import com.techcourse.controller.*; +import com.techcourse.controller.LoginController; +import com.techcourse.controller.LoginViewController; +import com.techcourse.controller.LogoutController; +import com.techcourse.controller.RegisterController; +import com.techcourse.controller.RegisterViewController; +import jakarta.servlet.http.HttpServletRequest; +import java.util.HashMap; +import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import webmvc.org.springframework.web.servlet.mvc.asis.Controller; import webmvc.org.springframework.web.servlet.mvc.asis.ForwardController; +import webmvc.org.springframework.web.servlet.mvc.tobe.HandlerMapping; -import java.util.HashMap; -import java.util.Map; - -public class ManualHandlerMapping { +public class ManualHandlerMapping implements HandlerMapping { private static final Logger log = LoggerFactory.getLogger(ManualHandlerMapping.class); - private static final Map controllers = new HashMap<>(); + @Override public void initialize() { controllers.put("/", new ForwardController("/index.jsp")); controllers.put("/login", new LoginController()); @@ -25,10 +30,12 @@ public void initialize() { log.info("Initialized Handler Mapping!"); controllers.keySet() - .forEach(path -> log.info("Path : {}, Controller : {}", path, controllers.get(path).getClass())); + .forEach(path -> log.info("Path : {}, Controller : {}", path, controllers.get(path).getClass())); } - public Controller getHandler(final String requestURI) { + @Override + public Object getHandler(final HttpServletRequest request) { + final String requestURI = request.getRequestURI(); log.debug("Request Mapping Uri : {}", requestURI); return controllers.get(requestURI); } diff --git a/app/src/test/java/com/techcourse/DispatcherServletTest.java b/app/src/test/java/com/techcourse/DispatcherServletTest.java new file mode 100644 index 0000000000..890dbcc960 --- /dev/null +++ b/app/src/test/java/com/techcourse/DispatcherServletTest.java @@ -0,0 +1,31 @@ +package com.techcourse; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +class DispatcherServletTest { + + @DisplayName("해당하는 HandlerMapping 이 존재하지 않으면 예외를 반환한다.") + @Test + void notExistHandlerMappingThrowsException() { + // given + final DispatcherServlet dispatcherServlet = new DispatcherServlet(); + dispatcherServlet.init(); + final HttpServletRequest request = mock(HttpServletRequest.class); + when(request.getRequestURI()).thenReturn("/not-exist"); + when(request.getMethod()).thenReturn("GET"); + + // when + // then + assertThatThrownBy(() -> dispatcherServlet.service(request, mock(HttpServletResponse.class))) + .isInstanceOf(ServletException.class) + .hasMessage("해당하는 HandlerMapping이 없습니다."); + } +} diff --git a/mvc/src/main/java/webmvc/org/springframework/web/servlet/ModelAndView.java b/mvc/src/main/java/webmvc/org/springframework/web/servlet/ModelAndView.java index ff8e24553f..fdf4e585d3 100644 --- a/mvc/src/main/java/webmvc/org/springframework/web/servlet/ModelAndView.java +++ b/mvc/src/main/java/webmvc/org/springframework/web/servlet/ModelAndView.java @@ -30,4 +30,8 @@ public Map getModel() { public View getView() { return view; } + + public String getViewName() { + return view.getName(); + } } diff --git a/mvc/src/main/java/webmvc/org/springframework/web/servlet/View.java b/mvc/src/main/java/webmvc/org/springframework/web/servlet/View.java index 4499f36866..ca5e002b7e 100644 --- a/mvc/src/main/java/webmvc/org/springframework/web/servlet/View.java +++ b/mvc/src/main/java/webmvc/org/springframework/web/servlet/View.java @@ -2,9 +2,11 @@ import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; - import java.util.Map; public interface View { + void render(Map model, HttpServletRequest request, HttpServletResponse response) throws Exception; + + String getName(); } diff --git a/mvc/src/main/java/webmvc/org/springframework/web/servlet/mvc/tobe/AnnotationHandlerAdapter.java b/mvc/src/main/java/webmvc/org/springframework/web/servlet/mvc/tobe/AnnotationHandlerAdapter.java new file mode 100644 index 0000000000..3a1fb33c8a --- /dev/null +++ b/mvc/src/main/java/webmvc/org/springframework/web/servlet/mvc/tobe/AnnotationHandlerAdapter.java @@ -0,0 +1,21 @@ +package webmvc.org.springframework.web.servlet.mvc.tobe; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import webmvc.org.springframework.web.servlet.ModelAndView; + +public class AnnotationHandlerAdapter implements HandlerAdapter { + + @Override + public boolean supports(final Object handler) { + return handler instanceof HandlerExecution; + } + + @Override + public ModelAndView handle(final HttpServletRequest request, final HttpServletResponse response, + final Object handler) + throws Exception { + final HandlerExecution handlerExecution = (HandlerExecution) handler; + return handlerExecution.handle(request, response); + } +} diff --git a/mvc/src/main/java/webmvc/org/springframework/web/servlet/mvc/tobe/AnnotationHandlerMapping.java b/mvc/src/main/java/webmvc/org/springframework/web/servlet/mvc/tobe/AnnotationHandlerMapping.java index 1bf6dc7f6c..0e977cc7a6 100644 --- a/mvc/src/main/java/webmvc/org/springframework/web/servlet/mvc/tobe/AnnotationHandlerMapping.java +++ b/mvc/src/main/java/webmvc/org/springframework/web/servlet/mvc/tobe/AnnotationHandlerMapping.java @@ -16,7 +16,7 @@ import web.org.springframework.web.bind.annotation.RequestMapping; import web.org.springframework.web.bind.annotation.RequestMethod; -public class AnnotationHandlerMapping { +public class AnnotationHandlerMapping implements HandlerMapping { private static final Logger log = LoggerFactory.getLogger(AnnotationHandlerMapping.class); @@ -28,6 +28,7 @@ public AnnotationHandlerMapping(final Object... basePackages) { this.handlerExecutions = new HashMap<>(); } + @Override public void initialize() { log.info("Initialized AnnotationHandlerMapping!"); makeHandlerExecutions(basePackages); @@ -95,6 +96,7 @@ private Map convertHandlerExecutions(final Object )); } + @Override public Object getHandler(final HttpServletRequest request) { final HandlerKey handlerKey = makeHandlerKey(request); return handlerExecutions.get(handlerKey); diff --git a/mvc/src/main/java/webmvc/org/springframework/web/servlet/mvc/tobe/HandlerAdapter.java b/mvc/src/main/java/webmvc/org/springframework/web/servlet/mvc/tobe/HandlerAdapter.java new file mode 100644 index 0000000000..13351a1575 --- /dev/null +++ b/mvc/src/main/java/webmvc/org/springframework/web/servlet/mvc/tobe/HandlerAdapter.java @@ -0,0 +1,12 @@ +package webmvc.org.springframework.web.servlet.mvc.tobe; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import webmvc.org.springframework.web.servlet.ModelAndView; + +public interface HandlerAdapter { + + boolean supports(Object handler); + + ModelAndView handle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception; +} diff --git a/mvc/src/main/java/webmvc/org/springframework/web/servlet/mvc/tobe/HandlerMapping.java b/mvc/src/main/java/webmvc/org/springframework/web/servlet/mvc/tobe/HandlerMapping.java new file mode 100644 index 0000000000..ba0798f4da --- /dev/null +++ b/mvc/src/main/java/webmvc/org/springframework/web/servlet/mvc/tobe/HandlerMapping.java @@ -0,0 +1,10 @@ +package webmvc.org.springframework.web.servlet.mvc.tobe; + +import jakarta.servlet.http.HttpServletRequest; + +public interface HandlerMapping { + + Object getHandler(final HttpServletRequest request); + + void initialize(); +} diff --git a/mvc/src/main/java/webmvc/org/springframework/web/servlet/mvc/tobe/ManualHandlerAdapter.java b/mvc/src/main/java/webmvc/org/springframework/web/servlet/mvc/tobe/ManualHandlerAdapter.java new file mode 100644 index 0000000000..88f73ded3e --- /dev/null +++ b/mvc/src/main/java/webmvc/org/springframework/web/servlet/mvc/tobe/ManualHandlerAdapter.java @@ -0,0 +1,27 @@ +package webmvc.org.springframework.web.servlet.mvc.tobe; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import webmvc.org.springframework.web.servlet.ModelAndView; +import webmvc.org.springframework.web.servlet.mvc.asis.Controller; +import webmvc.org.springframework.web.servlet.view.JspView; + +public class ManualHandlerAdapter implements HandlerAdapter { + + @Override + public boolean supports(final Object handler) { + return handler instanceof Controller; + } + + @Override + public ModelAndView handle( + final HttpServletRequest request, + final HttpServletResponse response, + final Object handler + ) + throws Exception { + final Controller controller = (Controller) handler; + final String path = controller.execute(request, response); + return new ModelAndView(new JspView(path)); + } +} diff --git a/mvc/src/main/java/webmvc/org/springframework/web/servlet/view/JsonView.java b/mvc/src/main/java/webmvc/org/springframework/web/servlet/view/JsonView.java index b42c3466f0..2c0ba3fc3b 100644 --- a/mvc/src/main/java/webmvc/org/springframework/web/servlet/view/JsonView.java +++ b/mvc/src/main/java/webmvc/org/springframework/web/servlet/view/JsonView.java @@ -11,4 +11,9 @@ public class JsonView implements View { @Override public void render(final Map model, final HttpServletRequest request, HttpServletResponse response) throws Exception { } + + @Override + public String getName() { + return null; + } } diff --git a/mvc/src/main/java/webmvc/org/springframework/web/servlet/view/JspView.java b/mvc/src/main/java/webmvc/org/springframework/web/servlet/view/JspView.java index 3f4cc906ff..aaa69da116 100644 --- a/mvc/src/main/java/webmvc/org/springframework/web/servlet/view/JspView.java +++ b/mvc/src/main/java/webmvc/org/springframework/web/servlet/view/JspView.java @@ -13,8 +13,10 @@ public class JspView implements View { private static final Logger log = LoggerFactory.getLogger(JspView.class); public static final String REDIRECT_PREFIX = "redirect:"; + private final String viewName; public JspView(final String viewName) { + this.viewName = viewName; } @Override @@ -28,4 +30,9 @@ public void render(final Map model, final HttpServletRequest request, // todo } + + @Override + public String getName() { + return viewName; + } } diff --git a/mvc/src/test/java/samples/TestController.java b/mvc/src/test/java/samples/TestController.java index 1f0e4acfb3..574d22934d 100644 --- a/mvc/src/test/java/samples/TestController.java +++ b/mvc/src/test/java/samples/TestController.java @@ -18,7 +18,7 @@ public class TestController { @RequestMapping(value = "/get-test", method = RequestMethod.GET) public ModelAndView findUserId(final HttpServletRequest request, final HttpServletResponse response) { log.info("test controller get method"); - final var modelAndView = new ModelAndView(new JspView("")); + final var modelAndView = new ModelAndView(new JspView("test")); modelAndView.addObject("id", request.getAttribute("id")); return modelAndView; } @@ -26,7 +26,7 @@ public ModelAndView findUserId(final HttpServletRequest request, final HttpServl @RequestMapping(value = "/post-test", method = RequestMethod.POST) public ModelAndView save(final HttpServletRequest request, final HttpServletResponse response) { log.info("test controller post method"); - final var modelAndView = new ModelAndView(new JspView("")); + final var modelAndView = new ModelAndView(new JspView("test")); modelAndView.addObject("id", request.getAttribute("id")); return modelAndView; } diff --git a/mvc/src/test/java/webmvc/org/springframework/web/servlet/mvc/tobe/AnnotationHandlerAdapterTest.java b/mvc/src/test/java/webmvc/org/springframework/web/servlet/mvc/tobe/AnnotationHandlerAdapterTest.java new file mode 100644 index 0000000000..8c988bc25b --- /dev/null +++ b/mvc/src/test/java/webmvc/org/springframework/web/servlet/mvc/tobe/AnnotationHandlerAdapterTest.java @@ -0,0 +1,54 @@ +package webmvc.org.springframework.web.servlet.mvc.tobe; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import samples.TestController; +import webmvc.org.springframework.web.servlet.ModelAndView; + +class AnnotationHandlerAdapterTest { + + @DisplayName("처리할 수 있는 핸들러인지 확인한다.") + @Test + void supports() { + // given + final var handlerAdapter = new AnnotationHandlerAdapter(); + + // when + // then + assertThat(handlerAdapter.supports(new HandlerExecution(null, null))).isTrue(); + } + + @DisplayName("처리할 수 없는 핸들러는 false 를 반환한다.") + @Test + void supportsReturnFalse() { + // given + final var handlerAdapter = new AnnotationHandlerAdapter(); + + // when + // then + assertThat(handlerAdapter.supports(new Object())).isFalse(); + } + + @DisplayName("핸들러의 동작을 수행한다.") + @Test + void handle() throws Exception { + // given + final var handlerAdapter = new AnnotationHandlerAdapter(); + + // when + final HttpServletRequest request = mock(HttpServletRequest.class); + when(request.getAttribute("id")).thenReturn("1"); + final ModelAndView modelAndView = handlerAdapter.handle(request, null, new HandlerExecution(new TestController(), + TestController.class.getMethod("findUserId", HttpServletRequest.class, HttpServletResponse.class))); + + // then + assertThat(modelAndView.getModel().get("id")).isEqualTo("1"); + assertThat(modelAndView.getViewName()).isEqualTo("test"); + } +} diff --git a/mvc/src/test/java/webmvc/org/springframework/web/servlet/mvc/tobe/AnnotationHandlerMappingTest.java b/mvc/src/test/java/webmvc/org/springframework/web/servlet/mvc/tobe/AnnotationHandlerMappingTest.java index dcec215a3f..c6ae5b490f 100644 --- a/mvc/src/test/java/webmvc/org/springframework/web/servlet/mvc/tobe/AnnotationHandlerMappingTest.java +++ b/mvc/src/test/java/webmvc/org/springframework/web/servlet/mvc/tobe/AnnotationHandlerMappingTest.java @@ -1,14 +1,14 @@ package webmvc.org.springframework.web.servlet.mvc.tobe; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - class AnnotationHandlerMappingTest { private AnnotationHandlerMapping handlerMapping; diff --git a/study/src/test/java/di/stage3/context/DIContainer.java b/study/src/test/java/di/stage3/context/DIContainer.java index b62feb1ed3..d0a73c50ea 100644 --- a/study/src/test/java/di/stage3/context/DIContainer.java +++ b/study/src/test/java/di/stage3/context/DIContainer.java @@ -1,5 +1,7 @@ package di.stage3.context; +import java.util.Arrays; +import java.util.HashSet; import java.util.Set; /** @@ -10,11 +12,32 @@ class DIContainer { private final Set beans; public DIContainer(final Set> classes) { - this.beans = Set.of(); + this.beans = new HashSet<>(); + + final Class userDao = classes.stream() + .filter(clazz -> Arrays.asList(clazz.getInterfaces()).contains(UserDao.class)) + .findFirst() + .orElseThrow(() -> new RuntimeException("UserDao가 존재하지 않습니다.")); + + final Class userService = classes.stream() + .filter(clazz -> clazz.isAssignableFrom(UserService.class)) + .findFirst() + .orElseThrow(() -> new RuntimeException("UserService가 존재하지 않습니다.")); + + try { + final UserDao userDaoInstance = (UserDao) userDao.getConstructor().newInstance(); + beans.add(userDaoInstance); + beans.add(userService.getConstructor(UserDao.class).newInstance(userDaoInstance)); + } catch (Exception e) { + throw new RuntimeException(e + " Bean 등록 시 예외가 발생했습니다."); + } } @SuppressWarnings("unchecked") public T getBean(final Class aClass) { - return null; + return (T) beans.stream() + .filter(bean -> aClass.isAssignableFrom(bean.getClass())) + .findFirst() + .orElseThrow(() -> new RuntimeException("해당하는 클래스가 존재하지 않습니다.")); } } diff --git a/study/src/test/java/di/stage4/annotations/ClassPathScanner.java b/study/src/test/java/di/stage4/annotations/ClassPathScanner.java index 9dab1fd9c4..ccd9b27f4f 100644 --- a/study/src/test/java/di/stage4/annotations/ClassPathScanner.java +++ b/study/src/test/java/di/stage4/annotations/ClassPathScanner.java @@ -1,10 +1,17 @@ package di.stage4.annotations; +import java.util.HashSet; import java.util.Set; +import org.reflections.Reflections; public class ClassPathScanner { public static Set> getAllClassesInPackage(final String packageName) { - return null; + final Reflections reflections = new Reflections(packageName); + final Set> response = new HashSet<>(reflections.getTypesAnnotatedWith(Inject.class)); + response.addAll(reflections.getTypesAnnotatedWith(Service.class)); + response.addAll(reflections.getTypesAnnotatedWith(Repository.class)); + + return response; } } diff --git a/study/src/test/java/di/stage4/annotations/DIContainer.java b/study/src/test/java/di/stage4/annotations/DIContainer.java index 9248ecad7e..c541c4eaf4 100644 --- a/study/src/test/java/di/stage4/annotations/DIContainer.java +++ b/study/src/test/java/di/stage4/annotations/DIContainer.java @@ -1,5 +1,9 @@ package di.stage4.annotations; +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.util.Arrays; +import java.util.HashSet; import java.util.Set; /** @@ -10,15 +14,62 @@ class DIContainer { private final Set beans; public DIContainer(final Set> classes) { - this.beans = Set.of(); + this.beans = new HashSet<>(); + for (final Class clazz : classes) { + try { + final Constructor constructor = clazz.getDeclaredConstructor(); + constructor.setAccessible(true); + beans.add(constructor.newInstance()); + } catch (Exception e) { + throw new RuntimeException(e + " Bean 등록 시 예외가 발생했습니다."); + } + } + initialize(classes); + } + + private void initialize(final Set> classes) { + for (final Class injectableClass : classes) { + updateFields(injectableClass); + } + } + + private void updateFields(final Class injectableClass) { + try { + setFields(injectableClass); + } catch (IllegalAccessException e) { + throw new RuntimeException("필드에 접근할 수 없습니다."); + } + } + + private void setFields(final Class injectableClass) throws IllegalAccessException { + final Field[] fields = injectableClass.getDeclaredFields(); + for (final Field field : fields) { + if (!field.isAnnotationPresent(Inject.class)) { + continue; + } + final Object instance = getBean(injectableClass); + field.setAccessible(true); + field.set(instance, getBean(field.getType())); + } } public static DIContainer createContainerForPackage(final String rootPackageName) { - return null; + final Set> injectableClasses = ClassPathScanner.getAllClassesInPackage(rootPackageName); + + return new DIContainer(injectableClasses); } @SuppressWarnings("unchecked") public T getBean(final Class aClass) { - return null; + if (aClass.isInterface()) { + return (T) beans.stream() + .filter(bean -> Arrays.asList(bean.getClass().getInterfaces()).contains(aClass)) + .findFirst() + .orElseThrow(() -> new RuntimeException("해당하는 클래스가 존재하지 않습니다.")); + } + return (T) beans.stream() + .filter(bean -> aClass.isAssignableFrom(bean.getClass())) + .findFirst() + .orElseThrow(() -> new RuntimeException("해당하는 클래스가 존재하지 않습니다.")); } }