diff --git a/packages/backend-core/src/cache/base/index.ts b/packages/backend-core/src/cache/base/index.ts index 433941b5c7..d59a3e9adf 100644 --- a/packages/backend-core/src/cache/base/index.ts +++ b/packages/backend-core/src/cache/base/index.ts @@ -129,6 +129,29 @@ export default class BaseCache { } } + async withCacheWithDynamicTTL( + key: string, + fetchFn: () => Promise<{ value: T; ttl: number | null }>, + opts = { useTenancy: true } + ): Promise { + const cachedValue = await this.get(key, opts) + if (cachedValue) { + return cachedValue + } + + try { + const fetchedResponse = await fetchFn() + const { value, ttl } = fetchedResponse + await this.store(key, value, ttl, { + useTenancy: opts.useTenancy, + }) + return value + } catch (err) { + console.error("Error fetching before cache - ", err) + throw err + } + } + async bustCache(key: string) { const client = await this.getClient() try { diff --git a/packages/backend-core/src/cache/generic.ts b/packages/backend-core/src/cache/generic.ts index 2d6d8b9472..22bdf47f00 100644 --- a/packages/backend-core/src/cache/generic.ts +++ b/packages/backend-core/src/cache/generic.ts @@ -2,14 +2,15 @@ import BaseCache from "./base" const GENERIC = new BaseCache() -export enum CacheKey { - CHECKLIST = "checklist", - INSTALLATION = "installation", - ANALYTICS_ENABLED = "analyticsEnabled", - UNIQUE_TENANT_ID = "uniqueTenantId", - EVENTS = "events", - BACKFILL_METADATA = "backfillMetadata", - EVENTS_RATE_LIMIT = "eventsRateLimit", +export const CacheKey = { + CHECKLIST: "checklist", + INSTALLATION: "installation", + ANALYTICS_ENABLED: "analyticsEnabled", + UNIQUE_TENANT_ID: "uniqueTenantId", + EVENTS: "events", + BACKFILL_METADATA: "backfillMetadata", + EVENTS_RATE_LIMIT: "eventsRateLimit", + OAUTH2_TOKEN: (configId: string) => `oauth2Token_${configId}`, } export enum TTL { @@ -29,5 +30,8 @@ export const destroy = (...args: Parameters) => export const withCache = ( ...args: Parameters> ) => GENERIC.withCache(...args) +export const withCacheWithDynamicTTL = ( + ...args: Parameters> +) => GENERIC.withCacheWithDynamicTTL(...args) export const bustCache = (...args: Parameters) => GENERIC.bustCache(...args) diff --git a/packages/server/src/integrations/rest.ts b/packages/server/src/integrations/rest.ts index 24ab9df425..219973fcb7 100644 --- a/packages/server/src/integrations/rest.ts +++ b/packages/server/src/integrations/rest.ts @@ -418,7 +418,7 @@ export class RestIntegration implements IntegrationBase { return headers } - async _req(query: RestQuery) { + async _req(query: RestQuery, retry401 = true): Promise { const { path = "", queryString = "", @@ -480,6 +480,14 @@ export class RestIntegration implements IntegrationBase { throw new Error("Cannot connect to URL.") } const response = await fetch(url, input) + if ( + response.status === 401 && + authConfigType === RestAuthType.OAUTH2 && + retry401 + ) { + await sdk.oauth2.cleanStoredToken(authConfigId!) + return await this._req(query, false) + } return await this.parseResponse(response, pagination) } diff --git a/packages/server/src/integrations/tests/rest.spec.ts b/packages/server/src/integrations/tests/rest.spec.ts index 71ff711352..7e4ed2172c 100644 --- a/packages/server/src/integrations/tests/rest.spec.ts +++ b/packages/server/src/integrations/tests/rest.spec.ts @@ -278,6 +278,29 @@ describe("REST Integration", () => { expect(data).toEqual({ foo: "bar" }) }) + function nockTokenCredentials( + oauth2Url: string, + clientId: string, + password: string, + resultCode: number, + resultBody: any + ) { + const url = new URL(oauth2Url) + const token = generator.guid() + nock(url.origin) + .post(url.pathname, { + grant_type: "client_credentials", + }) + .basicAuth({ user: clientId, pass: password }) + .reply(200, { token_type: "Bearer", access_token: token }) + + return nock("https://example.com", { + reqheaders: { Authorization: `Bearer ${token}` }, + }) + .get("/") + .reply(resultCode, resultBody) + } + it("adds OAuth2 auth (via header)", async () => { const oauth2Url = generator.url() const secret = generator.hash() @@ -290,22 +313,11 @@ describe("REST Integration", () => { grantType: OAuth2GrantType.CLIENT_CREDENTIALS, }) - const token = generator.guid() - - const url = new URL(oauth2Url) - nock(url.origin) - .post(url.pathname, { - grant_type: "client_credentials", - }) - .basicAuth({ user: oauthConfig.clientId, pass: secret }) - .reply(200, { token_type: "Bearer", access_token: token }) - - nock("https://example.com", { - reqheaders: { Authorization: `Bearer ${token}` }, + nockTokenCredentials(oauth2Url, oauthConfig.clientId, secret, 200, { + foo: "bar", }) - .get("/") - .reply(200, { foo: "bar" }) - const { data } = await config.doInContext( + + const { data, info } = await config.doInContext( config.appId, async () => await integration.read({ @@ -314,6 +326,7 @@ describe("REST Integration", () => { }) ) expect(data).toEqual({ foo: "bar" }) + expect(info.code).toEqual(200) }) it("adds OAuth2 auth (via body)", async () => { @@ -348,7 +361,8 @@ describe("REST Integration", () => { }) .get("/") .reply(200, { foo: "bar" }) - const { data } = await config.doInContext( + + const { data, info } = await config.doInContext( config.appId, async () => await integration.read({ @@ -357,6 +371,95 @@ describe("REST Integration", () => { }) ) expect(data).toEqual({ foo: "bar" }) + expect(info.code).toEqual(200) + }) + + it("handles OAuth2 auth cached expired token", async () => { + const oauth2Url = generator.url() + const secret = generator.hash() + const { config: oauthConfig } = await config.api.oauth2.create({ + name: generator.guid(), + url: oauth2Url, + clientId: generator.guid(), + clientSecret: secret, + method: OAuth2CredentialsMethod.HEADER, + grantType: OAuth2GrantType.CLIENT_CREDENTIALS, + }) + + nockTokenCredentials(oauth2Url, oauthConfig.clientId, secret, 401, {}) + const token2Request = nockTokenCredentials( + oauth2Url, + oauthConfig.clientId, + secret, + 200, + { + foo: "bar", + } + ) + + const { data, info } = await config.doInContext( + config.appId, + async () => + await integration.read({ + authConfigId: oauthConfig._id, + authConfigType: RestAuthType.OAUTH2, + }) + ) + + expect(data).toEqual({ foo: "bar" }) + expect(info.code).toEqual(200) + expect(token2Request.isDone()).toBe(true) + }) + + it("does not loop when handling OAuth2 auth cached expired token", async () => { + const oauth2Url = generator.url() + const secret = generator.hash() + const { config: oauthConfig } = await config.api.oauth2.create({ + name: generator.guid(), + url: oauth2Url, + clientId: generator.guid(), + clientSecret: secret, + method: OAuth2CredentialsMethod.HEADER, + grantType: OAuth2GrantType.CLIENT_CREDENTIALS, + }) + + const firstRequest = nockTokenCredentials( + oauth2Url, + oauthConfig.clientId, + secret, + 401, + {} + ) + const secondRequest = nockTokenCredentials( + oauth2Url, + oauthConfig.clientId, + secret, + 401, + {} + ) + const thirdRequest = nockTokenCredentials( + oauth2Url, + oauthConfig.clientId, + secret, + 200, + { foo: "bar" } + ) + + const { data, info } = await config.doInContext( + config.appId, + async () => + await integration.read({ + authConfigId: oauthConfig._id, + authConfigType: RestAuthType.OAUTH2, + }) + ) + + expect(info.code).toEqual(401) + expect(data).toEqual({}) + + expect(firstRequest.isDone()).toBe(true) + expect(secondRequest.isDone()).toBe(true) + expect(thirdRequest.isDone()).toBe(false) }) }) diff --git a/packages/server/src/sdk/app/oauth2/tests/utils.spec.ts b/packages/server/src/sdk/app/oauth2/tests/utils.spec.ts index 13d8b7e980..5739f1a794 100644 --- a/packages/server/src/sdk/app/oauth2/tests/utils.spec.ts +++ b/packages/server/src/sdk/app/oauth2/tests/utils.spec.ts @@ -69,6 +69,53 @@ describe("oauth2 utils", () => { expect(response).toEqual(expect.stringMatching(/^Bearer .+/)) }) + it("uses cached value if available", async () => { + const oauthConfig = await config.doInContext(config.appId, () => + sdk.oauth2.create({ + name: generator.guid(), + url: `${keycloakUrl}/realms/myrealm/protocol/openid-connect/token`, + clientId: "my-client", + clientSecret: "my-secret", + method, + grantType, + }) + ) + + const firstToken = await config.doInContext(config.appId, () => + getToken(oauthConfig._id) + ) + const secondToken = await config.doInContext(config.appId, () => + getToken(oauthConfig._id) + ) + + expect(firstToken).toEqual(secondToken) + }) + + it("refetches value if cache expired", async () => { + const oauthConfig = await config.doInContext(config.appId, () => + sdk.oauth2.create({ + name: generator.guid(), + url: `${keycloakUrl}/realms/myrealm/protocol/openid-connect/token`, + clientId: "my-client", + clientSecret: "my-secret", + method, + grantType, + }) + ) + + const firstToken = await config.doInContext(config.appId, () => + getToken(oauthConfig._id) + ) + await config.doInContext(config.appId, () => + sdk.oauth2.cleanStoredToken(oauthConfig._id) + ) + const secondToken = await config.doInContext(config.appId, () => + getToken(oauthConfig._id) + ) + + expect(firstToken).not.toEqual(secondToken) + }) + it("handles wrong urls", async () => { await expect( config.doInContext(config.appId, async () => { diff --git a/packages/server/src/sdk/app/oauth2/utils.ts b/packages/server/src/sdk/app/oauth2/utils.ts index 993fbe83b1..7817c88d0e 100644 --- a/packages/server/src/sdk/app/oauth2/utils.ts +++ b/packages/server/src/sdk/app/oauth2/utils.ts @@ -62,24 +62,32 @@ const trackUsage = async (id: string) => { }) } -// TODO: check if caching is worth export async function getToken(id: string) { - const config = await get(id) - if (!config) { - throw new HttpError(`oAuth config ${id} count not be found`) - } + const token = await cache.withCacheWithDynamicTTL( + cache.CacheKey.OAUTH2_TOKEN(id), + async () => { + const config = await get(id) + if (!config) { + throw new HttpError(`oAuth config ${id} count not be found`) + } - const resp = await fetchToken(config) + const resp = await fetchToken(config) - const jsonResponse = await resp.json() - if (!resp.ok) { - const message = jsonResponse.error_description ?? resp.statusText + const jsonResponse = await resp.json() + if (!resp.ok) { + const message = jsonResponse.error_description ?? resp.statusText - throw new Error(`Error fetching oauth2 token: ${message}`) - } + throw new Error(`Error fetching oauth2 token: ${message}`) + } + + const token = `${jsonResponse.token_type} ${jsonResponse.access_token}` + const ttl = jsonResponse.expires_in ?? -1 + return { value: token, ttl } + } + ) await trackUsage(id) - return `${jsonResponse.token_type} ${jsonResponse.access_token}` + return token } export async function validateConfig(config: { @@ -131,3 +139,7 @@ export async function getLastUsages(ids: string[]) { }, {}) return result } + +export async function cleanStoredToken(id: string) { + await cache.destroy(cache.CacheKey.OAUTH2_TOKEN(id), { useTenancy: true }) +}