1+ /*
2+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+ *
4+ * Licensed under the Apache License, Version 2.0 (the "License").
5+ * You may not use this file except in compliance with the License.
6+ * A copy of the License is located at
7+ *
8+ * http://aws.amazon.com/apache2.0
9+ *
10+ * or in the "license" file accompanying this file. This file is distributed
11+ * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+ * express or implied. See the License for the specific language governing
13+ * permissions and limitations under the License.
14+ */
15+
16+ package software .amazon .awssdk .services .signin .internal ;
17+
18+ import java .math .BigInteger ;
19+ import java .nio .BufferUnderflowException ;
20+ import java .nio .ByteBuffer ;
21+ import java .security .AlgorithmParameters ;
22+ import java .security .KeyFactory ;
23+ import java .security .NoSuchAlgorithmException ;
24+ import java .security .interfaces .ECPrivateKey ;
25+ import java .security .interfaces .ECPublicKey ;
26+ import java .security .spec .ECGenParameterSpec ;
27+ import java .security .spec .ECParameterSpec ;
28+ import java .security .spec .ECPoint ;
29+ import java .security .spec .ECPrivateKeySpec ;
30+ import java .security .spec .ECPublicKeySpec ;
31+ import java .security .spec .InvalidKeySpecException ;
32+ import java .security .spec .InvalidParameterSpecException ;
33+ import java .util .Arrays ;
34+ import java .util .Base64 ;
35+ import software .amazon .awssdk .annotations .SdkInternalApi ;
36+ import software .amazon .awssdk .utils .Pair ;
37+
38+ @ SdkInternalApi
39+ public final class EcKeyLoader {
40+
41+ private static final String SECP_256_R1_STD_NAME = "secp256r1" ;
42+
43+ private static final byte DER_SEQUENCE_TAG = 0x30 ;
44+ private static final byte DER_INTEGER_TAG = 0x02 ;
45+ private static final byte DER_OCTET_STRING_TAG = 0x04 ;
46+ private static final byte DER_BIT_STRING_TAG = 0x03 ;
47+ private static final byte DER_OPTIONAL_SEQ_PARAM_0 = (byte ) 0xA0 ;
48+ private static final byte DER_OPTIONAL_SEQ_PARAM_1 = (byte ) 0xA1 ;
49+ private static final byte DER_OBJECT_IDENTIFIER_TAG = 0x06 ;
50+
51+ private static final int SEC1_VERSION = 1 ;
52+
53+ // bytes for "1.2.840.10045.3.1.7" - the OID for secp256r1 aka prime256v1/NIST P-256
54+ private static byte [] SECP_256_R1_OID_BYTES = new byte [] {0x2A , (byte ) 0x86 , 0x48 , (byte ) 0xCE , 0x3D , 0x03 , 0x01 , 0x07 };
55+
56+ private EcKeyLoader () {
57+ }
58+
59+ /**
60+ * Load ECPrivateKey and ECPublicKey from a SEC1 / RFC 5915 ASN.1 formated PEM.
61+ * <p>
62+ * The only supported curve is: secp256r1.
63+ *
64+ * @param pem EC1 / RFC 5915 ASN.1 formated PEM contents
65+ * @return The ECPrivateKey and ECPublicKey
66+ */
67+ public static Pair <ECPrivateKey , ECPublicKey > loadSec1Pem (String pem ) {
68+ try {
69+ byte [] sec1Der = pemToDer (pem );
70+ ParsedEcKey parsed = parseSec1 (sec1Der );
71+ if (parsed .curveOid == null ) {
72+ throw new IllegalArgumentException ("Missing EC Curve OID" );
73+ }
74+ ECParameterSpec params = curveFromOid (parsed .curveOid );
75+
76+ // Create an ECPrivateKey from the parsed privateScalar value and the EC Curve (EC Parameters)
77+ ECPrivateKey privateKey = (ECPrivateKey ) KeyFactory
78+ .getInstance ("EC" )
79+ .generatePrivate (new ECPrivateKeySpec (parsed .privateScalar , params ));
80+
81+ // create an ECPublicKey from the public bytes
82+ if (parsed .publicBytes == null ) {
83+ throw new IllegalArgumentException ("Invalid certificate - public key is required." );
84+ }
85+ ECPublicKey publicKey = derivePublicFromBytes (parsed .publicBytes , privateKey .getParams ());
86+
87+ return Pair .of (privateKey , publicKey );
88+ } catch (NoSuchAlgorithmException | InvalidParameterSpecException | InvalidKeySpecException e ) {
89+ throw new RuntimeException (e );
90+ }
91+ }
92+
93+ // we only support one algorithm/curve: secp256r1, validate that the oid we have matches that and then build the curve
94+ private static ECParameterSpec curveFromOid (byte [] oid ) throws NoSuchAlgorithmException , InvalidParameterSpecException {
95+ if (Arrays .equals (SECP_256_R1_OID_BYTES , oid )) {
96+ AlgorithmParameters parameters = null ;
97+ parameters = AlgorithmParameters .getInstance ("EC" );
98+ parameters .init (new ECGenParameterSpec (SECP_256_R1_STD_NAME ));
99+ return parameters .getParameterSpec (ECParameterSpec .class );
100+ }
101+ throw new IllegalArgumentException ("Unsupported curve OID: " + Arrays .toString (oid ));
102+ }
103+
104+ // the public key is an octet string of the public X,Y with fixed lengths
105+ private static ECPublicKey derivePublicFromBytes (byte [] raw , ECParameterSpec params ) throws NoSuchAlgorithmException ,
106+ InvalidKeySpecException {
107+ if (raw [0 ] != DER_OCTET_STRING_TAG ) {
108+ throw new IllegalArgumentException ("Expected uncompressed point" );
109+ }
110+ int len = (raw .length - 1 ) / 2 ;
111+ BigInteger x = new BigInteger (1 , java .util .Arrays .copyOfRange (raw , 1 , 1 + len ));
112+ BigInteger y = new BigInteger (1 , java .util .Arrays .copyOfRange (raw , 1 + len , 1 + 2 * len ));
113+ ECPoint w = new ECPoint (x , y );
114+ ECPublicKeySpec spec = new ECPublicKeySpec (w , params );
115+ return (ECPublicKey ) KeyFactory .getInstance ("EC" ).generatePublic (spec );
116+ }
117+
118+ private static class ParsedEcKey {
119+ BigInteger privateScalar ;
120+ byte [] curveOid ;
121+ byte [] publicBytes ;
122+ }
123+
124+
125+ /**
126+ * Follows the SEC1 / RFC 5915 ASN.1 format: PrivateKeyInfo ::= SEQUENCE { version INTEGER (0), privateKeyAlgorithm
127+ * AlgorithmIdentifier, -- ecPublicKey + curve OID privateKey OCTET STRING -- contains the SEC1 DER parameters [0]
128+ * ECParameters {{ NamedCurve }} OPTIONAL, publicKey [1] BIT STRING OPTIONAL }
129+ * <p>
130+ * See: <a href="https://datatracker.ietf.org/doc/html/rfc5915#appendix-A">RFC 5915 - ASIN.1 format</a>
131+ *
132+ * @param der - asn.1 DER representing an EC private key with public key.
133+ * @return the parsed EC key, including the public key bytes.
134+ */
135+ private static ParsedEcKey parseSec1 (byte [] der ) {
136+ ParsedEcKey result = new ParsedEcKey ();
137+ ByteBuffer buffer = ByteBuffer .wrap (der );
138+ int len ;
139+ try {
140+ if (buffer .get () != DER_SEQUENCE_TAG ) {
141+ throw new IllegalArgumentException (
142+ "Invalid SEC1 Private Key: Not a SEQUENCE" );
143+ }
144+ readLength (buffer );
145+
146+ // validate the version
147+ if (buffer .get () != DER_INTEGER_TAG ) {
148+ throw new IllegalArgumentException (
149+ "Invalid SEC1 Private Key: Expected INTEGER" );
150+ }
151+ len = readLength (buffer );
152+ if (len != 1 || buffer .get () != SEC1_VERSION ) {
153+ throw new IllegalArgumentException ("Invalid SEC1 Private Key: invalid version" );
154+ }
155+
156+ // read private key
157+ if (buffer .get () != DER_OCTET_STRING_TAG ) {
158+ throw new IllegalArgumentException (
159+ "Invalid SEC1 Private Key: Expected OCTET STRING" );
160+ }
161+ len = readLength (buffer );
162+
163+ byte [] privateKeyBytes = new byte [len ];
164+ buffer .get (privateKeyBytes );
165+ result .privateScalar = new BigInteger (1 , privateKeyBytes );
166+
167+ while (buffer .hasRemaining ()) {
168+ byte tag = buffer .get ();
169+ len = readLength (buffer );
170+ if (tag == DER_OPTIONAL_SEQ_PARAM_0 ) { // [0] parameters (curve OID)
171+ if (buffer .get () != DER_OBJECT_IDENTIFIER_TAG ) {
172+ throw new IllegalArgumentException (
173+ "Invalid SEC1 Private Key: Expected OID" );
174+ }
175+ int oidLen = readLength (buffer );
176+ byte [] oid = new byte [oidLen ];
177+ buffer .get (oid );
178+ result .curveOid = oid ;
179+ } else if (tag == DER_OPTIONAL_SEQ_PARAM_1 ) { // [1] parameters public key (BIT STRING)
180+ byte bitTag = buffer .get ();
181+ if (bitTag != DER_BIT_STRING_TAG ) {
182+ throw new IllegalArgumentException (
183+ "Invalid SEC1 Private Key: Expected BIT STRING" );
184+ }
185+ int bitLen = readLength (buffer );
186+ byte [] bitString = new byte [bitLen ];
187+ buffer .get (bitString );
188+ // First byte of BIT STRING is the unused bits count, skip it
189+ result .publicBytes = java .util .Arrays .copyOfRange (bitString , 1 , bitString .length );
190+ } else {
191+ // ignore unknown
192+ buffer .position (buffer .position () + len );
193+ }
194+ }
195+ } catch (BufferUnderflowException e ) {
196+ throw new IllegalArgumentException ("Invalid SEC1 Private Key: failed to parse." , e );
197+ }
198+ return result ;
199+ }
200+
201+ // Strip header/footer and base64 decode to return the DER that was encoded in the PEM
202+ public static byte [] pemToDer (String pem ) {
203+ StringBuilder sb = new StringBuilder ();
204+ for (String line : pem .split ("\\ r?\\ n" )) {
205+ if (line .startsWith ("-----" )) {
206+ continue ;
207+ }
208+ sb .append (line .trim ());
209+ }
210+ return Base64 .getDecoder ().decode (sb .toString ());
211+ }
212+
213+ /**
214+ * Read a length from a DER byte input stream. lengths may be either a single byte (short form) or multiple bytes. If the
215+ * first bit is 0, then the remaining 7 bits give the length directly (short form). If the first bit is 1, then the next 7
216+ * bits give the number of bytes to read for the length. Eg: [0x82 0x01 0xF4] means the length is 2 bytes long (0x82) and the
217+ * length is 500 (0x01F4).
218+ *
219+ * Throws BufferUnderflowException if there are insufficient bytes
220+ *
221+ * @param buffer - byte buffer to read from
222+ * @return the length
223+ */
224+ private static int readLength (ByteBuffer buffer ) {
225+ int b = buffer .get () & 0xFF ; // convert signed byte to unsigned int
226+
227+ // if the high (first) bit is 0, then the length is a single byte, return it as is.
228+ if ((b & 0x80 ) == 0 ) {
229+ return b ;
230+ }
231+ // remove the leading 1 bit, this should give the number of bytes for the length
232+ int num = b & 0x7F ;
233+ if (num == 0 ) {
234+ throw new IllegalArgumentException ("Indefinite lengths not supported" );
235+ }
236+ // limit to 4 bytes, supported keys will never have more than 4 bytes of length
237+ if (num > 4 ) {
238+ throw new IllegalArgumentException ("Too many bytes in length" );
239+ }
240+ int val = 0 ;
241+
242+ // construct the length by reading num bytes from the input byte stream.
243+ for (int i = 0 ; i < num ; i ++) {
244+ int nb = buffer .get () & 0xFF ;
245+ if (nb < 0 ) {
246+ throw new IllegalArgumentException ("Unexpected EOF in length bytes" );
247+ }
248+ val = (val << 8 ) | (nb & 0xFF );
249+ }
250+ return val ;
251+ }
252+ }
0 commit comments