diff --git a/src/server.ts b/src/server.ts index 2890c94c..bcad80e6 100644 --- a/src/server.ts +++ b/src/server.ts @@ -200,8 +200,9 @@ export class InversifyExpressServer { res: express.Response, next: express.NextFunction ) { - let mReq = _self._container.get(middlewareItem); - (mReq as any).httpContext = _self._getHttpContext(req); + const httpContext = _self._getHttpContext(req); + let mReq = httpContext.container.get(middlewareItem); + (mReq as any).httpContext = httpContext; mReq.handler(req, res, next); }; } diff --git a/test/base_middleware.test.ts b/test/base_middleware.test.ts index ec54a592..03ff9c31 100644 --- a/test/base_middleware.test.ts +++ b/test/base_middleware.test.ts @@ -120,7 +120,7 @@ describe("BaseMiddleware", () => { supertest(server.build()) .get("/") - .expect(200, `test@test.com`, () => { + .expect(200, `test@test.com`, (err) => { expect(principalInstanceCount).eq( 1, "Only one instance of HttpContext should be created per HTTP request!" @@ -138,11 +138,11 @@ describe("BaseMiddleware", () => { "test@test.com => isAuthenticated() => true", "Expected action to be invoked 3rd!" ); - done(); + done(err); }); }); - it("Should allow the middleware to inject services in a HTTP request scope", (done) => { + it("Should allow the middleware to inject services in a HTTP request scope", async () => { const TRACE_HEADER = "X-Trace-Id"; @@ -182,7 +182,7 @@ describe("BaseMiddleware", () => { } } - @controller("/") + @controller("/tracing-test") class TracingTestController extends BaseHttpController { constructor(@inject(TYPES.Service) private readonly service: Service) { @@ -209,25 +209,24 @@ describe("BaseMiddleware", () => { const expectedRequests = 100; let handledRequests = 0; - run(expectedRequests, (executionId: number) => { + await run(expectedRequests, (executionId: number) => { return supertest(api) - .get("/") + .get("/tracing-test") .set(TRACE_HEADER, `trace-id-${ executionId }`) .expect(200, `trace-id-${ executionId }`) .then(res => { handledRequests++; }); - }, (err?: Error) => { - expect(handledRequests).eq( - expectedRequests, - `Only ${ handledRequests } out of ${ expectedRequests } have been handled correctly` - ); - done(err); }); - }); - it("Should not allow services injected into a HTTP request scope to be accessible outside the request scope", (done) => { + expect(handledRequests).eq( + expectedRequests, + `Only ${ handledRequests } out of ${ expectedRequests } have been handled correctly` + ); + + }); + it("Should not allow services injected into a HTTP request scope to be accessible outside the request scope", async () => { const TYPES = { Transaction: Symbol.for("Transaction"), TransactionMiddleware: Symbol.for("TransactionMiddleware"), @@ -243,12 +242,12 @@ describe("BaseMiddleware", () => { next: express.NextFunction ) { this.bind(TYPES.Transaction) - .toConstantValue(`I am transaction #${++this.count}\n`); + .toConstantValue(`I am transaction #${++this.count}`); next(); } } - @controller("/") + @controller("/transactional-tests") class TransactionTestController extends BaseHttpController { constructor(@inject(TYPES.Transaction) @optional() private transaction: string) { @@ -273,34 +272,90 @@ describe("BaseMiddleware", () => { const container = new Container(); - container.bind(TYPES.TransactionMiddleware).to(TransactionMiddleware); + container.bind(TYPES.TransactionMiddleware).to(TransactionMiddleware).inSingletonScope(); const app = new InversifyExpressServer(container).build(); - supertest(app) - .get("/1") - .expect(200, "I am transaction #1", () => { + await supertest(app) + .get("/transactional-tests/1") + .expect(200, "I am transaction #1"); + await supertest(app) + .get("/transactional-tests/1") + .expect(200, "I am transaction #2"); + await supertest(app) + .get("/transactional-tests/2") + .expect(204, ""); + }); - supertest(app) - .get("/1") - .expect(200, "I am transaction #2", () => { + it("Should allow constructor injections from http-scope in middlewares", async () => { - supertest(app) - .get("/2") - .expect(200, "", () => done()); - }); - }); + const TYPES = { + Value: Symbol.for("Value"), + ReadValue: Symbol.for("ReadValue"), + HttpContextValueSetMiddleware: Symbol.for("HttpContextValueSetMiddleware"), + HttpContextValueReadMiddleware: Symbol.for("HttpContextValueReadMiddleware"), + }; - }); -}); + class HttpContextValueSetMiddleware extends BaseMiddleware { + public handler( + req: express.Request, + res: express.Response, + next: express.NextFunction + ) { + this.bind(TYPES.Value).toConstantValue(`MyValue`); + next(); + } + } + + class HttpContextValueReadMiddleware extends BaseMiddleware { + constructor(@inject(TYPES.Value) private value: string) { + super(); + } + + public handler( + req: express.Request, + res: express.Response, + next: express.NextFunction + ) { + this.bind(TYPES.ReadValue).toConstantValue(`${this.value} is read`); + next(); + } + } + + @controller("/http-scope-middleware-injection-test") + class MiddlewareInjectionTestController extends BaseHttpController { + + constructor(@inject(TYPES.ReadValue) @optional() private value: string) { + super(); + } -function run(parallelRuns: number, test: (executionId: number) => PromiseLike, done: (error?: Error) => void) { - const testTaskNo = (id: number) => function(cb: (err?: Error) => void) { - test(id).then(cb, cb); - }; + @httpGet( + "/get-value", + TYPES.HttpContextValueSetMiddleware, + TYPES.HttpContextValueReadMiddleware + ) + public getValue() { + return this.value; + } + } - const testTasks = Array.from({ length: parallelRuns }, (val: undefined, key: number) => testTaskNo(key)); + const container = new Container(); + + container.bind(TYPES.HttpContextValueReadMiddleware) + .to(HttpContextValueReadMiddleware); + container.bind(TYPES.HttpContextValueSetMiddleware) + .to(HttpContextValueSetMiddleware); + container.bind(TYPES.Value).toConstantValue("DefaultValue"); + const app = new InversifyExpressServer(container).build(); + + await supertest(app) + .get("/http-scope-middleware-injection-test/get-value") + .expect(200, "MyValue is read"); + }); +}); - async.parallel(testTasks, done); +function run(parallelRuns: number, test: (executionId: number) => PromiseLike) { + const testTasks = Array.from({ length: parallelRuns }, (val: undefined, key: number) => test(key)); + return Promise.all(testTasks); } function someTimeBetween(minimum: number, maximum: number) {