package org.greenstone.gsdl3.auth.oidc;

import org.apache.cxf.rs.security.jose.jwa.SignatureAlgorithm;
import org.apache.cxf.rs.security.jose.jwk.JsonWebKey;
import org.apache.cxf.rs.security.jose.jwk.JsonWebKeys;
import org.apache.cxf.rs.security.jose.jwk.JwkUtils;
import org.apache.cxf.rs.security.jose.jws.JwsHeaders;
import org.apache.cxf.rs.security.jose.jws.PrivateKeyJwsSignatureProvider;
import org.apache.cxf.rs.security.jose.jwt.JoseJwtProducer;
import org.apache.cxf.rs.security.jose.jwt.JwtClaims;
import org.apache.cxf.rs.security.jose.jwt.JwtToken;

import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.util.Collections;
import java.util.Objects;

/**
 * {@link Gs3OidcProvider.KeyManager} implementation backed by Apache CXF JOSE that signs JWTs
 * with RS256.
 *
 * <p>Each instance generates a fresh 2048-bit RSA key pair in memory at construction time. The
 * public half is exposed as a one-key JWKS via {@link #jwksJson()}; the private half is held only
 * by the signature provider.
 *
 * <p>NOTE(review): the key id is a fixed constant while the key material changes on every
 * construction (nothing is persisted). Relying parties that cache the JWKS by {@code kid} could
 * keep using a stale key after a restart — consider persisting the key pair or deriving the kid
 * from the key material.
 */
public class CxfKeyManager implements Gs3OidcProvider.KeyManager {
  /** Key id advertised in the JWKS and stamped into every signed token's header. */
  private static final String KID = "oidc-dev-key";

  /** JWS algorithm used both for the published JWK and for signing. */
  private static final SignatureAlgorithm ALG = SignatureAlgorithm.RS256;

  private final JsonWebKeys publicJwks;
  private final PrivateKeyJwsSignatureProvider signer;
  private final JoseJwtProducer producer = new JoseJwtProducer();

  /**
   * Generates a new RSA key pair and prepares both the public JWKS and the RS256 signer.
   *
   * @throws RuntimeException if key generation or JWK construction fails
   */
  public CxfKeyManager() {
    try {
      KeyPairGenerator kpg = KeyPairGenerator.getInstance("RSA");
      kpg.initialize(2048);
      KeyPair kp = kpg.generateKeyPair();

      // The second argument of fromRSAPublicKey is the JWA algorithm name (e.g. "RS256"),
      // not the key use; the use is set separately below.
      JsonWebKey jwk = JwkUtils.fromRSAPublicKey((RSAPublicKey) kp.getPublic(), ALG.getJwaName());
      jwk.setKeyId(KID);
      // RFC 7517 "use" parameter: mark this key as intended for signature verification.
      jwk.setProperty("use", "sig");
      this.publicJwks = new JsonWebKeys(Collections.singletonList(jwk));

      this.signer = new PrivateKeyJwsSignatureProvider((RSAPrivateKey) kp.getPrivate(), ALG);
    } catch (Exception e) {
      throw new RuntimeException("Failed to initialize CXF KeyManager", e);
    }
  }

  /** Returns the key id placed in the JWKS entry and in every signed token's header. */
  @Override
  public String kid() {
    return KID;
  }

  /** Returns the public key set (a single RS256 signing key) serialized as JWKS JSON. */
  @Override
  public String jwksJson() {
    return JwkUtils.jwkSetToJson(publicJwks);
  }

  /**
   * Signs {@code claims} with this instance's RSA private key.
   *
   * <p>The {@code alg} and {@code kid} headers are always overwritten with RS256 and
   * {@link #kid()}. When {@code headers} is non-null, that caller-supplied object is modified in
   * place.
   *
   * @param headers additional JWS headers to include, or {@code null} for none
   * @param claims the JWT claims to sign; must not be {@code null}
   * @return the signed token as serialized by {@link JoseJwtProducer}
   * @throws NullPointerException if {@code claims} is {@code null}
   * @throws Exception if CXF fails to produce the signed token
   */
  @Override
  public String signJwt(JwsHeaders headers, JwtClaims claims) throws Exception {
    Objects.requireNonNull(claims, "claims");
    JwsHeaders jwsHeaders = (headers != null) ? headers : new JwsHeaders();
    jwsHeaders.setAlgorithm(ALG.getJwaName());
    jwsHeaders.setKeyId(KID);
    // No JWE encryption provider: the token is signed only, not encrypted.
    return producer.processJwt(new JwtToken(jwsHeaders, claims), /* jwe */ null, signer);
  }
}
