package org.greenstone.gsdl3.auth.oidc;

import org.apache.cxf.rs.security.jose.jws.JwsHeaders;
import org.apache.cxf.rs.security.jose.jwt.JwtClaims;

import javax.servlet.http.*;
import java.io.IOException;
import java.io.PrintWriter;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.time.Duration;
import java.time.Instant;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Minimal OpenID Connect provider for Greenstone 3 implementing the authorization-code flow with
 * mandatory PKCE ({@code S256}).
 *
 * <p>{@link #handleAuthorize} issues short-lived single-use authorization codes, {@link
 * #handleToken} exchanges them for a signed ID token plus an opaque bearer access token, {@link
 * #handleUserInfo} returns claims for that access token, and {@link #handleJwks} publishes the
 * signing keys. Routing HTTP requests to these methods is done by the caller (not visible here).
 *
 * <p>Clients are registered as a {@code client_id -> redirect_uri} map: each client has exactly
 * one allowed redirect URI, which must match byte-for-byte. No client secret is checked at the
 * token endpoint, so PKCE is what binds a code to the client that requested it.
 */
public class GS3OidcProvider {

  /** Resolves the logged-in Greenstone user for an incoming request. */
  public interface SessionAccessor {
    /** Returns the id of the logged-in user, or {@code null} when nobody is logged in. */
    String currentUserId(HttpServletRequest req);

    /**
     * Invoked by {@link GS3OidcProvider#handleAuthorize} when {@link #currentUserId} returned
     * {@code null}. The default renders the built-in development login form.
     */
    default void redirectToLogin(HttpServletRequest req, HttpServletResponse resp) throws IOException {
      renderDevLogin(req, resp);
    }
  }

  /** Supplies the profile values released as the {@code email} and {@code name} claims. */
  public interface UserDirectory {
    String emailFor(String userId);
    String displayNameFor(String userId);
  }

  /** Persistence for authorization codes issued by {@link GS3OidcProvider#handleAuthorize}. */
  public interface AuthCodeStore {
    void save(AuthCode code);

    /**
     * Returns the record for {@code code}, or {@code null} if unknown. Implementations should also
     * remove it so a code can be redeemed only once (as {@link InMemoryAuthCodeStore} does).
     */
    AuthCode take(String code);
  }

  /** Persistence for opaque bearer access tokens issued by {@link GS3OidcProvider#handleToken}. */
  public interface AccessTokenStore {
    void save(AccessToken tok);

    /** Returns the record for {@code token}, or {@code null} if unknown. */
    AccessToken find(String token);
  }

  /** Signs ID tokens and publishes the corresponding public keys. */
  public interface KeyManager {
    /** Key id of the signing key (not read directly by this class). */
    String kid();

    /** JWK Set document served verbatim by {@link GS3OidcProvider#handleJwks}. */
    String jwksJson();

    /** Signs {@code claims}; the returned string is sent to the client as {@code id_token}. */
    String signJwt(JwsHeaders headers, JwtClaims claims) throws Exception;
  }

  /** An issued authorization code and the request parameters it is bound to. */
  public static final class AuthCode {
    public String code;          // the opaque code value handed to the client
    public String clientId;      // must match client_id at the token endpoint
    public String redirectUri;   // must match redirect_uri at the token endpoint
    public String userId;        // the user who authorized the request
    public String scope;         // space-separated granted scopes
    public String nonce;         // echoed into the ID token's "nonce" claim
    public String codeChallenge; // PKCE S256 challenge checked against code_verifier
    public Instant expiresAt;    // codes at or after this instant are rejected
  }

  /** An issued opaque bearer access token. */
  public static final class AccessToken {
    public String token;      // the opaque token value handed to the client
    public String userId;     // the user the token acts for
    public String scope;      // space-separated granted scopes
    public Instant expiresAt; // tokens past this instant are rejected by userinfo
  }

  /**
   * Builder for {@link GS3OidcProvider}. Only {@code issuer} is required; every collaborator left
   * unset falls back to a development/in-memory default in {@link #build()}.
   */
  public static class Builder {
    private String issuer;
    private String basePath;
    private Map<String,String> clients = new HashMap<String,String>();
    private SessionAccessor sessionAccessor;
    private UserDirectory userDirectory;
    private AuthCodeStore authCodeStore;
    private AccessTokenStore accessTokenStore;
    private KeyManager keyManager;
    private Set<String> allowedScopes = new LinkedHashSet<String>(Arrays.asList("openid","email","profile"));

    public Builder issuer(String iss) { this.issuer = iss; return this; }
    public Builder basePath(String base) { this.basePath = base; return this; }

    /** Registers clients as {@code client_id -> redirect_uri}; {@code null} means no clients. */
    public Builder clients(Map<String,String> c) {
      this.clients = (c == null) ? new HashMap<String,String>() : new HashMap<String,String>(c);
      return this;
    }

    public Builder sessionAccessor(SessionAccessor s) { this.sessionAccessor = s; return this; }
    public Builder userDirectory(UserDirectory d) { this.userDirectory = d; return this; }
    public Builder authCodeStore(AuthCodeStore s) { this.authCodeStore = s; return this; }
    public Builder accessTokenStore(AccessTokenStore s) { this.accessTokenStore = s; return this; }
    public Builder keyManager(KeyManager k) { this.keyManager = k; return this; }

    /**
     * Replaces the set of scopes clients may request. Entries are trimmed and blanks dropped; a
     * null, empty, or all-blank collection leaves the current set unchanged.
     */
    public Builder allowedScopes(Collection<String> scopes) {
      if (scopes != null && !scopes.isEmpty()) {
        LinkedHashSet<String> s = new LinkedHashSet<String>();
        for (String v : scopes) if (v != null && !v.trim().isEmpty()) s.add(v.trim());
        if (!s.isEmpty()) this.allowedScopes = s;
      }
      return this;
    }

    /**
     * Builds the provider, substituting defaults for unset collaborators.
     *
     * @throws IllegalArgumentException if no issuer was set
     */
    public GS3OidcProvider build() {
      if (issuer == null) throw new IllegalArgumentException("issuer required");
      if (sessionAccessor == null)  sessionAccessor = new DevSessionAccessor();
      if (userDirectory == null)    userDirectory = new DevUserDirectory();
      if (authCodeStore == null)    authCodeStore = new InMemoryAuthCodeStore();
      if (accessTokenStore == null) accessTokenStore = new InMemoryAccessTokenStore();
      if (keyManager == null)       keyManager = new CxfKeyManager();
      return new GS3OidcProvider(issuer, basePath, clients, sessionAccessor, userDirectory,
                                 authCodeStore, accessTokenStore, keyManager, allowedScopes);
    }
  }

  /** Lifetime of an authorization code. */
  private static final long AUTH_CODE_TTL_SECONDS = 120;
  /** Lifetime of a signed ID token ({@code exp - iat}). */
  private static final long ID_TOKEN_TTL_SECONDS = 600;
  /** Lifetime of an access token; also reported to the client as {@code expires_in}. */
  private static final Duration ACCESS_TOKEN_TTL = Duration.ofMinutes(30);

  private final String issuer;
  private final String basePath;
  private final Map<String,String> clients;
  private final SessionAccessor sessionAccessor;
  private final UserDirectory userDirectory;
  private final AuthCodeStore authCodes;
  private final AccessTokenStore tokens;
  private final KeyManager keys;
  private final Set<String> allowedScopes;

  private static final SecureRandom RNG = new SecureRandom();

  private GS3OidcProvider(String issuer, String basePath, Map<String, String> clients,
                          SessionAccessor sessionAccessor, UserDirectory userDirectory,
                          AuthCodeStore authCodes, AccessTokenStore tokens, KeyManager keys,
                          Set<String> allowedScopes) {
    this.issuer = stripTrailingSlash(issuer);
    this.basePath = basePath == null ? "" : stripTrailingSlash(basePath);
    this.clients = Collections.unmodifiableMap(new HashMap<String,String>(clients));
    this.sessionAccessor = sessionAccessor;
    this.userDirectory = userDirectory;
    this.authCodes = authCodes;
    this.tokens = tokens;
    this.keys = keys;
    this.allowedScopes = Collections.unmodifiableSet(new LinkedHashSet<String>(allowedScopes));
  }

  /**
   * Reports an authorization-endpoint error. When {@code redirectUri} is non-null it is sent back
   * to the client as a 302 with {@code error} (and {@code state}, if any) in the query; otherwise a
   * 400 JSON error is written. Callers must only pass a redirect URI already verified as registered.
   */
  private static void authzError(HttpServletResponse resp,
                                 String redirectUri,
                                 String state,
                                 String error) throws IOException {
    if (redirectUri != null) {
      redirectWithParam(resp, redirectUri, "error", error, state);
    } else {
      jsonError(resp, 400, error);
    }
  }

  /**
   * Writes {@code {"error": ...}} with the given HTTP status and no-store caching, as used by the
   * token endpoint and for authorization errors that must not be redirected.
   */
  private static void jsonError(HttpServletResponse resp,
                                int status,
                                String error) throws IOException {
    resp.setStatus(status);
    resp.setContentType("application/json");
    resp.setCharacterEncoding("UTF-8");
    resp.setHeader("Cache-Control", "no-store");
    resp.setHeader("Pragma", "no-cache");
    resp.getWriter().write("{\"error\":\"" + escJ(error) + "\"}");
  }

  /** Serves the key manager's JWK Set as JSON, cacheable publicly for one hour. */
  public void handleJwks(HttpServletResponse resp) throws IOException {
    resp.setContentType("application/json");
    resp.setCharacterEncoding("UTF-8");
    resp.setHeader("Cache-Control", "public, max-age=3600");
    PrintWriter out = resp.getWriter();
    out.write(keys.jwksJson());
    out.flush();
  }

  /**
   * Authorization endpoint. Validates the request, then either asks the user to log in or
   * redirects to the client's registered redirect URI with a fresh authorization code.
   *
   * <p>Requirements: {@code response_type=code}, a registered {@code client_id} whose
   * {@code redirect_uri} matches exactly, a scope set containing {@code openid} and drawn only from
   * the allowed scopes, {@code code_challenge_method=S256}, a {@code code_challenge} of at least 43
   * characters, and a non-blank {@code nonce}.
   */
  public void handleAuthorize(HttpServletRequest req, HttpServletResponse resp) throws IOException {
    String responseType  = req.getParameter("response_type");
    String clientId      = req.getParameter("client_id");
    String redirectUri   = req.getParameter("redirect_uri");
    String scope         = req.getParameter("scope");
    String state         = req.getParameter("state");
    String nonce         = req.getParameter("nonce");
    String codeChallenge = req.getParameter("code_challenge");
    String codeMethod    = req.getParameter("code_challenge_method");

    // Client and redirect_uri are checked BEFORE anything can redirect: bouncing errors to an
    // unverified redirect_uri would turn this endpoint into an open redirector
    // (RFC 6749 section 4.1.2.1). These errors are therefore shown directly, never redirected.
    if (isBlank(clientId) || isBlank(redirectUri)) { jsonError(resp, 400, "invalid_request"); return; }
    String regRedirect = clients.get(clientId);
    if (regRedirect == null || !regRedirect.equals(redirectUri)) { jsonError(resp, 400, "unauthorized_client"); return; }

    // From here on redirect_uri is known to be registered, so errors go back to the client.
    if (!"code".equals(responseType)) { authzError(resp, redirectUri, state, "unsupported_response_type"); return; }

    Set<String> requested = parseScopes(scope);
    if (!requested.contains("openid") || !allowedScopes.containsAll(requested)) {
      authzError(resp, redirectUri, state, "invalid_scope"); return;
    }

    if (!"S256".equalsIgnoreCase(codeMethod)) { authzError(resp, redirectUri, state, "invalid_request"); return; }
    if (isBlank(nonce) || isBlank(codeChallenge) || codeChallenge.length() < 43) {
      authzError(resp, redirectUri, state, "invalid_request"); return;
    }

    String userId = sessionAccessor.currentUserId(req);
    if (userId == null) { sessionAccessor.redirectToLogin(req, resp); return; }

    AuthCode ac = new AuthCode();
    ac.code          = randomUrlSafe(32);
    ac.clientId      = clientId;
    ac.redirectUri   = redirectUri;
    ac.userId        = userId;
    ac.scope         = String.join(" ", requested); // normalized: single spaces, no duplicates
    ac.nonce         = nonce;
    ac.codeChallenge = codeChallenge;
    ac.expiresAt     = Instant.now().plusSeconds(AUTH_CODE_TTL_SECONDS);
    authCodes.save(ac);

    redirectWithParam(resp, redirectUri, "code", ac.code, state);
  }

  /**
   * Token endpoint ({@code grant_type=authorization_code} only). Redeems a code once, verifies it
   * against {@code client_id}, {@code redirect_uri} and the PKCE {@code code_verifier}, then
   * returns a signed ID token and an opaque bearer access token as JSON.
   */
  public void handleToken(HttpServletRequest req, HttpServletResponse resp) throws IOException {
    req.setCharacterEncoding("UTF-8");
    String grantType    = req.getParameter("grant_type");
    String code         = req.getParameter("code");
    String redirectUri  = req.getParameter("redirect_uri");
    String clientId     = req.getParameter("client_id");
    String codeVerifier = req.getParameter("code_verifier");

    if (!"authorization_code".equals(grantType)) { jsonError(resp, 400, "unsupported_grant_type"); return; }
    // Without this check a missing code reaches the store as null (an NPE in the in-memory store).
    if (isBlank(code)) { jsonError(resp, 400, "invalid_request"); return; }

    AuthCode ac = authCodes.take(code);
    if (ac == null)                                       { jsonError(resp, 400, "invalid_grant"); return; }
    if (ac.expiresAt == null || ac.expiresAt.isBefore(Instant.now())) { jsonError(resp, 400, "invalid_grant"); return; }
    if (!Objects.equals(ac.clientId, clientId))           { jsonError(resp, 400, "invalid_grant"); return; }
    if (!Objects.equals(ac.redirectUri, redirectUri))     { jsonError(resp, 400, "invalid_grant"); return; }
    if (!pkceValidS256(ac.codeChallenge, codeVerifier))   { jsonError(resp, 400, "invalid_grant"); return; }

    Set<String> granted = parseScopes(ac.scope);
    long nowSec = System.currentTimeMillis() / 1000L;
    JwtClaims claims = new JwtClaims();
    claims.setIssuer(issuer());
    claims.setSubject(stableSubFor(ac.userId));
    claims.setAudience(clientId);
    claims.setIssuedAt(nowSec);
    claims.setExpiryTime(nowSec + ID_TOKEN_TTL_SECONDS);
    claims.setClaim("nonce", ac.nonce);
    if (granted.contains("email"))   claims.setClaim("email", userDirectory.emailFor(ac.userId));
    if (granted.contains("profile")) claims.setClaim("name",  userDirectory.displayNameFor(ac.userId));

    String idToken;
    try {
      idToken = keys.signJwt(new JwsHeaders(), claims);
    } catch (Exception e) {
      // NOTE(review): the cause is dropped because no logger is available in this class;
      // consider logging it so signing failures can be diagnosed.
      jsonError(resp, 500, "server_error"); return;
    }

    AccessToken tok = new AccessToken();
    tok.token     = randomUrlSafe(32);
    tok.userId    = ac.userId;
    tok.scope     = ac.scope;
    tok.expiresAt = Instant.ofEpochSecond(nowSec).plus(ACCESS_TOKEN_TTL);
    tokens.save(tok);

    Map<String,Object> out = new LinkedHashMap<String,Object>();
    out.put("access_token", tok.token);
    out.put("token_type", "Bearer");
    // expires_in describes the ACCESS token (RFC 6749 section 5.1), so it must match tok.expiresAt.
    out.put("expires_in", ACCESS_TOKEN_TTL.getSeconds());
    out.put("id_token", idToken);
    writeJson(resp, out, true);
  }

  /**
   * UserInfo endpoint. Accepts {@code Authorization: Bearer <token>} and returns {@code sub} plus
   * {@code email}/{@code name} when the token's scope includes {@code email}/{@code profile}.
   * Missing, unknown or expired tokens get 401 with a {@code WWW-Authenticate} challenge.
   */
  public void handleUserInfo(HttpServletRequest req, HttpServletResponse resp) throws IOException {
    final String prefix = "Bearer ";
    String authz = req.getHeader("Authorization");
    // The auth scheme name is case-insensitive (RFC 7235).
    if (authz == null || !authz.regionMatches(true, 0, prefix, 0, prefix.length())) {
      resp.setHeader("WWW-Authenticate", "Bearer");
      resp.setStatus(401);
      return;
    }
    String tokenVal = authz.substring(prefix.length()).trim();
    AccessToken t = tokens.find(tokenVal);
    if (t == null || t.expiresAt == null || t.expiresAt.isBefore(Instant.now())) {
      resp.setHeader("WWW-Authenticate", "Bearer error=\"invalid_token\"");
      resp.setStatus(401);
      return;
    }

    Set<String> granted = parseScopes(t.scope);
    Map<String,Object> claims = new LinkedHashMap<String,Object>();
    claims.put("sub", stableSubFor(t.userId));
    if (granted.contains("email"))   claims.put("email", userDirectory.emailFor(t.userId));
    if (granted.contains("profile")) claims.put("name",  userDirectory.displayNameFor(t.userId));
    writeJson(resp, claims, true);
  }

  private String issuer() { return issuer; }
  private static String stripTrailingSlash(String s) { return (s != null && s.endsWith("/")) ? s.substring(0, s.length()-1) : s; }
  private static boolean isBlank(String s) { return s == null || s.trim().isEmpty(); }
  private static String url(String s) { return URLEncoder.encode(s == null ? "" : s, StandardCharsets.UTF_8); }

  /** Splits a space-separated scope string into an ordered set of non-empty scope tokens. */
  private static Set<String> parseScopes(String scope) {
    Set<String> out = new LinkedHashSet<String>();
    if (scope == null) return out;
    for (String s : scope.trim().split("\\s+")) {
      if (!s.isEmpty()) out.add(s);
    }
    return out;
  }

  /**
   * Appends an already-encoded query fragment to {@code uri}, keeping any query component the
   * registered URI already has (RFC 6749 section 3.1.2).
   */
  private static String appendQuery(String uri, String params) {
    if (uri.endsWith("?") || uri.endsWith("&")) return uri + params;
    return uri + (uri.indexOf('?') >= 0 ? "&" : "?") + params;
  }

  /** Sends a 302 to {@code redirectUri} carrying {@code key=value} and, if present, {@code state}. */
  private static void redirectWithParam(HttpServletResponse resp, String redirectUri,
                                        String key, String value, String state) {
    StringBuilder params = new StringBuilder(key).append('=').append(url(value));
    if (state != null) params.append("&state=").append(url(state));
    resp.setStatus(302);
    resp.setHeader("Location", appendQuery(redirectUri, params.toString()));
  }

  /** Returns {@code bytes} random bytes from a {@link SecureRandom}, base64url-encoded without padding. */
  private static String randomUrlSafe(int bytes) {
    byte[] b = new byte[bytes]; RNG.nextBytes(b);
    return Base64.getUrlEncoder().withoutPadding().encodeToString(b);
  }

  /**
   * Checks a PKCE verifier: base64url(SHA-256(ASCII(verifier))), unpadded, must equal the stored
   * challenge. The comparison is constant-time.
   */
  private static boolean pkceValidS256(String codeChallenge, String codeVerifier) {
    if (codeChallenge == null || isBlank(codeVerifier)) return false;
    try {
      java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-256");
      byte[] hash = md.digest(codeVerifier.getBytes(StandardCharsets.US_ASCII));
      String derived = Base64.getUrlEncoder().withoutPadding().encodeToString(hash);
      return java.security.MessageDigest.isEqual(
          codeChallenge.getBytes(StandardCharsets.US_ASCII),
          derived.getBytes(StandardCharsets.US_ASCII));
    } catch (java.security.NoSuchAlgorithmException e) {
      return false;
    }
  }

  /**
   * Derives the OIDC {@code sub} for a user: "gs3_" plus the first 22 base64url characters of
   * SHA-256("gs3|" + userId), so the raw user id is not exposed to clients.
   */
  private static String stableSubFor(String userId) {
    try {
      java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-256");
      byte[] h = md.digest(("gs3|" + userId).getBytes(java.nio.charset.StandardCharsets.UTF_8));
      return "gs3_" + Base64.getUrlEncoder().withoutPadding().encodeToString(h).substring(0, 22);
    } catch (Exception e) { return "gs3_" + userId; }
  }

  /** Writes {@code map} as a JSON object; {@code noStore} adds no-store/no-cache headers. */
  private static void writeJson(HttpServletResponse resp, Map<String, ?> map, boolean noStore) throws IOException {
    resp.setContentType("application/json");
    resp.setCharacterEncoding("UTF-8");
    if (noStore) {
      resp.setHeader("Cache-Control", "no-store");
      resp.setHeader("Pragma", "no-cache");
    }
    String json = toJson(map);
    PrintWriter out = resp.getWriter();
    out.write(json);
    out.flush();
  }

  /**
   * Minimal JSON object serializer. Values may be null, Number/Boolean (emitted bare), a
   * Collection (emitted as an array of strings), or anything else (emitted via String.valueOf).
   */
  private static String toJson(Map<String, ?> map) {
    StringBuilder sb = new StringBuilder();
    sb.append("{");
    boolean first = true;
    for (Map.Entry<String, ?> e : map.entrySet()) {
      if (!first) sb.append(",");
      first = false;
      sb.append("\"").append(escJ(e.getKey())).append("\":");
      Object v = e.getValue();
      if (v == null) sb.append("null");
      else if (v instanceof Number || v instanceof Boolean) sb.append(v.toString());
      else if (v instanceof Collection) {
        sb.append("[");
        boolean f2 = true;
        for (Object o : (Collection<?>) v) {
          if (!f2) sb.append(",");
          f2 = false;
          sb.append("\"").append(escJ(String.valueOf(o))).append("\"");
        }
        sb.append("]");
      } else {
        sb.append("\"").append(escJ(String.valueOf(v))).append("\"");
      }
    }
    sb.append("}");
    return sb.toString();
  }

  /**
   * Escapes a string for use inside a JSON string literal. Every control character below U+0020
   * is escaped, as JSON (RFC 8259) requires, not only newline, carriage return and tab.
   */
  private static String escJ(String s) {
    StringBuilder sb = new StringBuilder(s.length() + 8);
    for (int i = 0; i < s.length(); i++) {
      char c = s.charAt(i);
      switch (c) {
        case '\\': sb.append("\\\\"); break;
        case '"':  sb.append("\\\""); break;
        case '\n': sb.append("\\n"); break;
        case '\r': sb.append("\\r"); break;
        case '\t': sb.append("\\t"); break;
        case '\b': sb.append("\\b"); break;
        case '\f': sb.append("\\f"); break;
        default:
          if (c < 0x20) sb.append(String.format("\\u%04x", (int) c));
          else sb.append(c);
      }
    }
    return sb.toString();
  }

  /** Development session lookup: the user id is the String session attribute "gs3_user". */
  public static class DevSessionAccessor implements SessionAccessor {
    @Override public String currentUserId(HttpServletRequest req) {
      HttpSession s = req.getSession(false);
      return (s != null) ? (String) s.getAttribute("gs3_user") : null;
    }
  }

  /** Development directory: email is {@code userId@example.invalid}, display name is the id. */
  public static class DevUserDirectory implements UserDirectory {
    @Override public String emailFor(String userId) { return userId + "@example.invalid"; }
    @Override public String displayNameFor(String userId) { return userId; }
  }

  /**
   * In-process code store. {@link #take} removes the code, making it single-use. Expired codes
   * that were never redeemed are purged on each save so the map cannot grow without bound.
   */
  public static class InMemoryAuthCodeStore implements AuthCodeStore {
    private final Map<String, AuthCode> map = new ConcurrentHashMap<String, AuthCode>();
    @Override public void save(AuthCode code) {
      final Instant now = Instant.now();
      map.values().removeIf(c -> c.expiresAt != null && c.expiresAt.isBefore(now));
      map.put(code.code, code);
    }
    @Override public AuthCode take(String code) { return code == null ? null : map.remove(code); }
  }

  /** In-process access-token store; expired tokens are purged on each save. */
  public static class InMemoryAccessTokenStore implements AccessTokenStore {
    private final Map<String, AccessToken> map = new ConcurrentHashMap<String, AccessToken>();
    @Override public void save(AccessToken tok) {
      final Instant now = Instant.now();
      map.values().removeIf(t -> t.expiresAt != null && t.expiresAt.isBefore(now));
      map.put(tok.token, tok);
    }
    @Override public AccessToken find(String token) { return token == null ? null : map.get(token); }
  }

  /**
   * Renders a development-only login form that POSTs back to the current URL. It re-submits every
   * incoming request parameter as a hidden field, except username, password and __dev_login.
   * Handling of the posted credentials is not in this class.
   */
  static void renderDevLogin(HttpServletRequest req, HttpServletResponse resp) throws IOException {
    resp.setContentType("text/html;charset=UTF-8");
    PrintWriter out = resp.getWriter();
    out.println("<html><body><h3>GS3 Login (dev)</h3>");
    out.println("<form method='POST' action=''>");
    out.println("<input type='hidden' name='__dev_login' value='1'/>");
    Enumeration<String> names = req.getParameterNames();
    while (names.hasMoreElements()) {
      String k = names.nextElement();
      if (k.equals("username") || k.equals("password") || k.equals("__dev_login")) continue;
      String[] vals = req.getParameterValues(k);
      if (vals == null) continue;
      for (String v : vals) {
        // Both the parameter NAME and VALUE come from the request, so both must be escaped.
        out.println("<input type='hidden' name='" + escHtml(k) + "' value='" + escHtml(v) + "'/>");
      }
    }
    out.println("Username: <input name='username'/><br/>");
    out.println("Password: <input name='password' type='password'/><br/>");
    out.println("<button type='submit'>Sign in</button>");
    out.println("</form></body></html>");
  }

  /**
   * HTML-escapes text for element content and for quoted attribute values. The single quote is
   * escaped too, because the attributes above are single-quoted.
   */
  private static String escHtml(String s) {
    if (s == null) return "";
    return s.replace("&","&amp;").replace("<","&lt;").replace(">","&gt;")
            .replace("\"","&quot;").replace("'","&#39;");
  }
}
