/* * Copyright 2018 Okta, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.okta.authn.sdk.example; import com.okta.sdk.lang.Strings; import org.apache.shiro.web.servlet.OncePerRequestFilter; import org.apache.shiro.web.util.WebUtils; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; import java.io.IOException; import java.util.Locale; import java.util.UUID; public class OverlySimpleCsrfFilter extends OncePerRequestFilter { private static final String CSRF_KEY = "_csrf"; private boolean shouldFilter(ServletRequest request) { HttpServletRequest httpRequest = WebUtils.toHttp(request); String method = httpRequest.getMethod().toUpperCase(Locale.ENGLISH); return "POST".equals(method) || "PUT".equals(method); // POST or PUT } @Override protected void doFilterInternal(ServletRequest request, ServletResponse response, FilterChain chain) throws ServletException, IOException { HttpSession session = WebUtils.toHttp(request).getSession(true); String expectedCsrf = (String) session.getAttribute(CSRF_KEY); // figure out the next CSRF token String nextCSRF = UUID.randomUUID().toString(); request.setAttribute(CSRF_KEY, nextCSRF); if (shouldFilter(request)) { String actualCsrf = request.getParameter(CSRF_KEY); // if the csrf token does not match stop processing the filter if (Strings.isEmpty(expectedCsrf) || !expectedCsrf.equals(actualCsrf)) { request.getServletContext().log("CSRF token did not match"); WebUtils.toHttp(response).sendError(HttpServletResponse.SC_BAD_REQUEST); return; } } chain.doFilter(request, response); // next key session.setAttribute(CSRF_KEY, nextCSRF); } }