import java.io.*;
import java.io.BufferedReader;
import java.nio.charset.StandardCharsets;
import java.security.KeyFactory;
import java.security.Security;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.interfaces.*;
import java.security.interfaces.RSAPrivateKey;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Base64;
import java.util.Base64.*;

import com.nimbusds.jose.*;
import com.nimbusds.jose.JWEAlgorithm;
import com.nimbusds.jose.crypto.*;
import com.nimbusds.jose.crypto.bc.BouncyCastleProviderSingleton;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jwt.EncryptedJWT;

class PGW_JWT_TEST {
	private String keyDir,rsaPubKeyFileName,rsaPrivateKeyFileName;

    public PGW_JWT_TEST() {
		keyDir = System.getProperty("user.dir") + File.separator + "test-keys" + File.separator;
		rsaPubKeyFileName = keyDir + "sandbox-jwt-2c2p.demo.2.1.public.cer";
		rsaPrivateKeyFileName = keyDir + "test_sign_pri.pem";
	}

    public static void main(String[] args) throws Exception {
        String output = "error:";
        try {
            String method = "";
            String data = "";
            if (args.length >= 2 && !args[0].isEmpty()) {
                method = args[0];
            }
            if (args.length >= 2 && !args[1].isEmpty()) {
                data = args[1];
            }
            if (!method.isEmpty() && !data.isEmpty()) {
                PGW_JWT_TEST pgwJwt = new PGW_JWT_TEST();
                switch (method) {
                    case "decryptJWEToken":
                        output = pgwJwt.decryptJWEToken(data);
                        break;
                    case "createJWEToken":
                        output = pgwJwt.createJWEToken(data);
                        break;
                    default:
                        output = output + "no_such_method_exist";
                        break;
                }
            }

        } catch (Exception ex) {
            output = output + ex.getMessage();
        } finally {
            System.out.print(output);
        }
    }

    private RSAPublicKey getRSAPublickey(String filename) throws Exception {
        FileInputStream is = new FileInputStream(filename);
        CertificateFactory certFactory = CertificateFactory.getInstance("X509");
        X509Certificate jwePubKey = (X509Certificate) certFactory.generateCertificate(is);
        RSAKey rsaJWE = RSAKey.parse(jwePubKey);
        RSAPublicKey jweRsaPubKey = rsaJWE.toRSAPublicKey();
        return jweRsaPubKey;
    }

    private RSAPrivateKey geRSAPrivateKey(String filename) throws Exception {
        String privateKeyPEM = "";
        BufferedReader br = new BufferedReader(new FileReader(filename));
        String line;
        while ((line = br.readLine()) != null) {
            privateKeyPEM += line + "\n";
        }
        br.close();
        privateKeyPEM = privateKeyPEM.replace("-----BEGIN PRIVATE KEY-----\n", "");
        privateKeyPEM = privateKeyPEM.replace("-----END PRIVATE KEY-----", "");
        byte[] encoded = Base64.getMimeDecoder().decode(privateKeyPEM);
        KeyFactory kf = KeyFactory.getInstance("RSA");
        PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(encoded);
        RSAPrivateKey jwsPrivateKey = (RSAPrivateKey) kf.generatePrivate(keySpec);
        return jwsPrivateKey;
    }

    private String createEpayload(String paymentRequest, RSAPublicKey jweRsaPubKey) throws Exception {
        JWEAlgorithm alg = JWEAlgorithm.RSA_OAEP;
        EncryptionMethod enc = EncryptionMethod.A256GCM;
        JWEObject jwe = new JWEObject(new JWEHeader(alg, enc), new Payload(paymentRequest));
        jwe.encrypt(new RSAEncrypter(jweRsaPubKey));
        String jwePayload = jwe.serialize();
        return jwePayload;
    }

    private String createRSASign(String jwePayload, RSAPrivateKey jwsPrivateKey) throws Exception {
        Security.addProvider(BouncyCastleProviderSingleton.getInstance());
        RSASSASigner signer = new RSASSASigner(jwsPrivateKey);
        JWSHeader header = new JWSHeader(JWSAlgorithm.PS256, JOSEObjectType.JWT, null, null, null, null, null, null, null, null, null, null, null);
        JWSObject jwsObject = new JWSObject(header, new Payload(jwePayload));
        jwsObject.sign(signer);
        String jwsPayload = jwsObject.serialize();
        return jwsPayload;
    }

    private boolean verifyRSASign(JWSObject jwsObject, RSAPublicKey jweRsaPubKey) throws Exception {
        Security.addProvider(BouncyCastleProviderSingleton.getInstance());
        boolean verified = jwsObject.verify(new RSASSAVerifier(jweRsaPubKey)); //return true represent valid JWS, else invalid.
        return verified;
    }

    private String decryptJWEPayload(JWSObject jwsObject, RSAPrivateKey jwsPrivateKey) throws Exception {
        JWEObject jwe = EncryptedJWT.parse(jwsObject.getPayload().toString());
        jwe.decrypt(new RSADecrypter(jwsPrivateKey));
        String responsePayload = jwe.getPayload().toString();
        return responsePayload;
    }

    protected String createJWEToken(String paymentRequest) throws Exception {
        RSAPublicKey jweRsaPubKey = getRSAPublickey(rsaPubKeyFileName);
        RSAPrivateKey jwsPrivateKey = geRSAPrivateKey(rsaPrivateKeyFileName);
        String jwePayload = createEpayload(paymentRequest, jweRsaPubKey);
        String jwsPayload = createRSASign(jwePayload, jwsPrivateKey);
        return jwsPayload;
    }

    protected String decryptJWEToken(String jwsResponse) throws Exception {
        JWSObject jwsObject = JWSObject.parse(jwsResponse);
        RSAPublicKey jweRsaPubKey = getRSAPublickey(rsaPubKeyFileName);
        RSAPrivateKey jwsPrivateKey = geRSAPrivateKey(rsaPrivateKeyFileName);
        boolean verified = verifyRSASign(jwsObject, jweRsaPubKey);
        String responsePayload = "error:sign_unverified";
        if (verified == true) {
            responsePayload = decryptJWEPayload(jwsObject, jwsPrivateKey);
        }
        return responsePayload;
    }
}
