package org.greenstone.gsdl3.auth.oidc;

import org.apache.cxf.rs.security.jose.jwa.SignatureAlgorithm;
import org.apache.cxf.rs.security.jose.jwk.JsonWebKey;
import org.apache.cxf.rs.security.jose.jwk.JsonWebKeys;
import org.apache.cxf.rs.security.jose.jwk.JwkUtils;
import org.apache.cxf.rs.security.jose.jws.JwsHeaders;
import org.apache.cxf.rs.security.jose.jws.PrivateKeyJwsSignatureProvider;
import org.apache.cxf.rs.security.jose.jwt.JoseJwtProducer;
import org.apache.cxf.rs.security.jose.jwt.JwtClaims;
import org.apache.cxf.rs.security.jose.jwt.JwtToken;

import java.io.FileInputStream;
import java.security.Key;
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Properties;
    
/**
 * {@link Gs3OidcProvider.KeyManager} that loads an RSA signing key pair from a Java
 * keystore and uses Apache CXF JOSE to sign JWTs with RS256 and to publish the matching
 * public key as a JWK set.
 *
 * <p>All fields are assigned once in the constructor and never reassigned.
 */
public class CxfKeystoreKeyManager implements Gs3OidcProvider.KeyManager {
  /** JWA name of the only signature algorithm this manager supports. */
  private static final String RS256 = SignatureAlgorithm.RS256.getJwaName();

  /** Name of the property from which {@code JwkUtils.fromPublicKey} reads the algorithm. */
  private static final String ALG_PROPERTY = "alg";

  private final String kid;
  private final PrivateKeyJwsSignatureProvider signer;
  private final JsonWebKeys publicJwks;
  private final JoseJwtProducer producer = new JoseJwtProducer();

  /**
   * Loads the signing key and its certificate from the keystore described by {@code cfg}.
   *
   * <p>When {@code cfg.keyAlias} is {@code null}, the first key entry found in the keystore
   * is used.
   *
   * @param cfg OIDC configuration; must have a keystore configured and a non-null {@code kid}
   * @throws IllegalArgumentException if no keystore or no {@code kid} is configured
   * @throws RuntimeException if the keystore cannot be read, or the selected entry is not an
   *     RSA private key with an accompanying certificate (the cause carries the details)
   */
  public CxfKeystoreKeyManager(OidcConfig cfg) {
    if (!cfg.useKeystore()) {
      throw new IllegalArgumentException("Keystore not configured in OidcConfig");
    }
    if (cfg.kid == null) {
      // Previously a null kid only surfaced as an opaque NullPointerException further down.
      throw new IllegalArgumentException("Key id (kid) not configured in OidcConfig");
    }
    this.kid = cfg.kid;

    try (FileInputStream fis = new FileInputStream(cfg.keystorePath)) {
      KeyStore ks = KeyStore.getInstance(cfg.keystoreType);
      ks.load(fis, toChars(cfg.keystorePassword));

      String alias = cfg.keyAlias != null ? cfg.keyAlias : firstKeyAlias(ks);
      Key k = ks.getKey(alias, toChars(cfg.keyPassword));
      if (!(k instanceof PrivateKey)) {
        throw new IllegalStateException("Alias " + alias + " does not hold a PrivateKey");
      }
      X509Certificate cert = (X509Certificate) ks.getCertificate(alias);
      if (cert == null) {
        throw new IllegalStateException("No certificate for alias " + alias);
      }
      // RS256 is hard-coded below, so reject other key types here instead of failing on
      // the first signJwt() call.
      if (!"RSA".equals(k.getAlgorithm())
          || !"RSA".equals(cert.getPublicKey().getAlgorithm())) {
        throw new IllegalStateException(
            "Alias " + alias + " does not hold an RSA key pair (required for " + RS256 + ")");
      }

      // The third argument of fromPublicKey is the NAME of the property holding the
      // algorithm, not the key id.
      Properties props = new Properties();
      props.setProperty(ALG_PROPERTY, RS256);
      JsonWebKey jwk = JwkUtils.fromPublicKey(cert.getPublicKey(), props, ALG_PROPERTY);
      // Set the descriptive members explicitly so the published key always carries the
      // same kid that signJwt() puts into the JWS header.
      jwk.setAlgorithm(RS256);
      jwk.setKeyId(kid);
      jwk.setProperty(JsonWebKey.PUBLIC_KEY_USE, "sig");
      this.publicJwks = new JsonWebKeys(Collections.singletonList(jwk));

      this.signer = new PrivateKeyJwsSignatureProvider((PrivateKey) k, SignatureAlgorithm.RS256);
    } catch (Exception e) {
      throw new RuntimeException("Failed to load OIDC signing key from keystore", e);
    }
  }

  /** Converts a possibly-null password string to the {@code char[]} form KeyStore expects. */
  private static char[] toChars(String s) {
    return s == null ? null : s.toCharArray();
  }

  /**
   * Returns the first alias in {@code ks} that denotes a key entry.
   *
   * @throws IllegalStateException if the keystore contains no key entries
   */
  private static String firstKeyAlias(KeyStore ks) throws Exception {
    Enumeration<String> en = ks.aliases();
    while (en.hasMoreElements()) {
      String a = en.nextElement();
      if (ks.isKeyEntry(a)) {
        return a;
      }
    }
    throw new IllegalStateException("No key entries in keystore");
  }

  /** Returns the key id placed in signed JWT headers and in the published JWK. */
  @Override
  public String kid() {
    return kid;
  }

  /** Returns the public JWK set (a single RSA signing key) serialized as JSON. */
  @Override
  public String jwksJson() {
    return JwkUtils.jwkSetToJson(publicJwks);
  }

  /**
   * Signs {@code claims} as a compact JWS using RS256 and this manager's key.
   *
   * <p>The {@code alg} and {@code kid} headers are always overwritten; note that a non-null
   * {@code headers} argument is modified in place.
   *
   * @param headers JWS headers to start from, or {@code null} for a fresh header set
   * @param claims the JWT claims to sign
   * @return the value produced by {@code JoseJwtProducer.processJwt} for the signed token
   * @throws Exception if CXF fails to produce the signed token
   */
  @Override
  public String signJwt(JwsHeaders headers, JwtClaims claims) throws Exception {
    if (headers == null) {
      headers = new JwsHeaders();
    }
    headers.setAlgorithm(RS256);
    headers.setKeyId(kid);
    return producer.processJwt(new JwtToken(headers, claims), /* jwe */ null, signer);
  }
}
