JWT with CORS for FastAPI

September 25, 2022

The problem: if you implement it naively, the bearer token you send never reaches the server. This is a chicken-and-egg issue. The browser first sends a pre-flight request without the Authorization header, so a JWT check on that request fails. The browser then refuses to send the actual request (the one carrying the token) because the pre-flight did not go through. The solution: write your own middleware. You simply need to detect pre-flight requests and answer them with a 200 status code and the proper headers before checking any token. If you expect multiple origins, you can even implement your own CORS rules or disable them altogether (this is strongly discouraged for APIs storing sensitive data: the rules exist for a reason).

If you use FastAPI, you may be familiar with its CORS middleware. However, adding JWT-based authorization to a CORS-enabled API is not as simple as adding another middleware.

CORS (Cross-Origin Resource Sharing) is the mechanism that lets a server make exceptions to the browser's same-origin policy, allowing a web page to send requests to APIs hosted on a different origin.

I have found that the easiest way to do JWT authentication (with Auth0) is to write your own FastAPI middleware that handles both CORS and authorization.

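To implement the middleware, we have to understand that before every non-simple request (for example, one carrying an Authorization header), the browser sends a pre-flight OPTIONS request to ask the server whether the subsequent request is acceptable in terms of origin, method and headers. This pre-flight request does not include the Authorization header itself.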
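As a concrete illustration: per the Fetch standard, a pre-flight is an OPTIONS request that carries both an Origin header and an Access-Control-Request-Method header. The stdlib-only sketch below checks exactly that; `is_preflight` is a hypothetical helper, not a FastAPI API. The example headers also show that a pre-flight only announces the *names* of the headers to come, never the token.

```python
from typing import Mapping


def is_preflight(method: str, headers: Mapping[str, str]) -> bool:
    """Return True if the request looks like a CORS pre-flight.

    A pre-flight is an OPTIONS request carrying an Origin header and an
    Access-Control-Request-Method header (the method the browser wants to use next).
    Header names are compared case-insensitively.
    """
    names = {name.lower() for name in headers}
    return (
        method.upper() == "OPTIONS"
        and "origin" in names
        and "access-control-request-method" in names
    )


# Roughly what a browser sends before a cross-origin POST with an Authorization header.
# The Authorization header itself is absent: only its name is listed.
preflight_headers = {
    "Origin": "https://app.example.com",
    "Access-Control-Request-Method": "POST",
    "Access-Control-Request-Headers": "authorization, content-type",
}
print(is_preflight("OPTIONS", preflight_headers))
print(is_preflight("OPTIONS", {"Origin": "https://app.example.com"}))
```

The first call prints `True` and the second `False`. Treating every OPTIONS request as a pre-flight, as the middleware in this post does, is a simpler and coarser alternative.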

To handle pre-flight requests, we just need to detect them and answer with a 200 status and the correct set of headers.

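from fastapi import FastAPI, Request
from fastapi.responses import PlainTextResponse

app = FastAPI()
# Exact origins allowed to call the API (scheme + host, no trailing slash)
ALLOWED_ORIGINS = {"https://example.com"}

@app.middleware("http")
async def authorize_request(request: Request, call_next):
    # CORS requests identify the caller through the Origin header (Referer may be absent or contain a path)
    origin = request.headers.get("Origin", "").rstrip("/")
    cors_headers = {
        "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
        "Access-Control-Allow-Credentials": "true",
        "Access-Control-Allow-Origin": origin,
        "Access-Control-Allow-Headers": "Origin, X-Requested-With, Content-Type, Accept, Authorization",
        "Access-Control-Max-Age": "86400",
        "Vary": "Origin",  # the allowed origin is echoed back, so caches must key on it
    }
    if request.method == "OPTIONS":
        # Pre-flight: answer before any token check, since it never carries Authorization
        if origin not in ALLOWED_ORIGINS:  # exact match; a substring test would accept evil-example.com
            print(origin, "not in allowed origins")
            return PlainTextResponse("CORS error", status_code=401)
        return PlainTextResponse("OK", status_code=200, headers=cors_headers)
    # ... regular requests are handled in the following steps ...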
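Because the pre-flight answer and the regular responses share the same headers, one option is to build them in a pure function that also performs the origin check, so the allowlist can be unit-tested without running a server. A stdlib-only sketch, where `cors_headers_for` and the example origins are hypothetical: it matches origins exactly, since a substring test such as `"example.com" in origin` would also accept `https://evil-example.com`.

```python
from typing import Optional, Set


def cors_headers_for(origin: str, allowed_origins: Set[str]) -> Optional[dict]:
    """Return the CORS headers to send for `origin`, or None if it is not allowed."""
    origin = origin.rstrip("/")
    if origin not in allowed_origins:  # exact match, not a substring test
        return None
    return {
        "Access-Control-Allow-Origin": origin,
        "Access-Control-Allow-Credentials": "true",
        "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
        "Access-Control-Allow-Headers": "Origin, X-Requested-With, Content-Type, Accept, Authorization",
        "Access-Control-Max-Age": "86400",
        # The origin is echoed back, so shared caches must key responses on it
        "Vary": "Origin",
    }


allowed = {"https://example.com", "https://app.example.com"}
# A substring test such as `"example.com" in origin` would accept all three candidates.
for candidate in ["https://app.example.com", "https://evil-example.com", "https://example.com.attacker.net"]:
    verdict = "allowed" if cors_headers_for(candidate, allowed) else "rejected"
    print(candidate, "->", verdict)
```

Only the first candidate is allowed. In a middleware, a `None` result would map to the rejection branch and the returned dict to the headers of both the 200 pre-flight answer and the regular responses.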

Once this is done, we can treat other requests normally, making sure the same CORS headers are set on every response:

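@app.middleware("http")
async def authorize_request(request: Request, call_next):

    # ... handle pre-flight request (defines origin and cors_headers) ...

    response = await call_next(request)
    # Copy the CORS headers onto the actual response too, otherwise
    # the browser refuses to expose it to the calling page
    for name, value in cors_headers.items():
        response.headers[name] = value
    return response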

Now that CORS is properly set up, we can handle JWT authorization. With Auth0 this is fairly straightforward: once we have made sure we are not handling a pre-flight request, we verify and decode the Bearer token.

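from fastapi.responses import JSONResponse

@app.middleware("http")
async def authorize_request(request: Request, call_next):

    # ... handle pre-flight request (defines cors_headers) ...

    # Error responses need the CORS headers too, or the browser hides the 401 from the page
    if "Authorization" not in request.headers:
        return JSONResponse(
            status_code=401,
            content={"detail": "No credentials found (authorization bearer or cookie)"},
            headers=cors_headers,
        )
    try:
        token = decode_auth_header(request.headers["Authorization"])
        payload = get_token_payload(token)
    except Exception:
        # Malformed header, bad signature, expired token, unknown signing key...
        return JSONResponse(status_code=401, content={"detail": "Invalid token"}, headers=cors_headers)
    # Raising inside a middleware would surface as a 500, so answer 401 explicitly
    if "https://api.example.com/" not in payload["aud"]:
        return JSONResponse(status_code=401, content={"detail": "Unauthorized"}, headers=cors_headers)

    # ... handle request and respond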
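RFC 7519 allows the `aud` claim to be either a single string or a list of strings, and `x in payload["aud"]` behaves differently in the two cases: on a string it is a substring test, so `"https://api.example.com/"` would match `"https://api.example.com/v2"`. A small helper (the name `has_audience` is hypothetical) that compares exactly in both cases might look like this. Note that `jwt.decode(..., audience=...)` already validates the audience, so this is an extra safeguard rather than a replacement.

```python
from typing import Any, Mapping


def has_audience(payload: Mapping[str, Any], expected: str) -> bool:
    """Return True if the token's `aud` claim contains `expected`.

    Per RFC 7519, `aud` is either a single string or a list of strings.
    """
    aud = payload.get("aud")
    if isinstance(aud, str):
        return aud == expected  # exact comparison, not a substring test
    if isinstance(aud, list):
        return expected in aud  # membership in the list of audiences
    return False


# e.g. an Auth0 access token requested with the openid scope can list both the API and /userinfo
payload = {"aud": ["https://api.example.com/", "https://your-tenant.auth0.com/userinfo"]}
print(has_audience(payload, "https://api.example.com/"))
print(has_audience({"aud": "https://api.example.com/v2"}, "https://api.example.com/"))
```

The first call prints `True`; the second prints `False`, where a plain `in` on the string would have returned `True`.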

To decode the auth header, we simply need to strip the Bearer prefix. To get the token payload, we use the python-jose library to verify the token against the public RSA keys (JWKS) that Auth0 publishes for the tenant.

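import requests
from jose import jwt

AUTH0_DOMAIN = "your-tenant.auth0.com"      # placeholder: your Auth0 domain
API_AUDIENCE = "https://api.example.com/"   # placeholder: your API identifier
ALGORITHMS = ["RS256"]

def get_token_payload(token: str):
    unverified_header = jwt.get_unverified_header(token)
    # Fetch Auth0's public signing keys (JWKS) and index them by key id
    jwks = requests.get(f"https://{AUTH0_DOMAIN}/.well-known/jwks.json", timeout=5).json()
    rsa_keys = {k["kid"]: k for k in jwks["keys"]}
    if unverified_header["alg"] not in ALGORITHMS:
        raise Exception("Unsupported JWT algorithm")
    if unverified_header.get("kid") not in rsa_keys:
        raise Exception("Unknown signing key")
    # Verifies signature, expiry, audience and issuer; raises jose.JWTError on failure
    payload = jwt.decode(
        token,
        rsa_keys[unverified_header["kid"]],
        algorithms=ALGORITHMS,  # pin accepted algorithms instead of trusting the token header
        audience=API_AUDIENCE,
        issuer=f"https://{AUTH0_DOMAIN}/",
    )
    return payload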
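`decode_auth_header` is not shown in the snippets above. A minimal sketch, assuming the header has the form `Bearer <token>`; the HTTP auth scheme name is case-insensitive, so `bearer` is accepted too. Malformed values raise `ValueError`, which a caller can map to a 401 response.

```python
def decode_auth_header(header_value: str) -> str:
    """Extract the raw JWT from an `Authorization: Bearer <token>` header value."""
    scheme, _, token = header_value.strip().partition(" ")
    # The scheme must be "Bearer" (any case) and followed by a non-empty token
    if scheme.lower() != "bearer" or not token.strip():
        raise ValueError("Authorization header must be 'Bearer <token>'")
    return token.strip()


print(decode_auth_header("Bearer eyJhbGciOiJSUzI1NiJ9.e30.sig"))
```

This prints the bare token `eyJhbGciOiJSUzI1NiJ9.e30.sig`, ready to be passed to `get_token_payload`.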

There you have it: with this, you should be able to set up JWT-based authorization on cross-origin endpoints. Full code. Original Twitter thread.


Written by Frédéric Gessler. Deep learning, hardware accelerators and compilers enthusiast. Hacker. Builder. EPFL and Amazon alumnus.