Skip to content

Commit cc7398d

Browse files
Add Support JdbcUserCredentialRepository
Closes spring-projectsgh-16224
1 parent 5a81a1f commit cc7398d

File tree

7 files changed

+636
-1
lines changed

7 files changed

+636
-1
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright 2002-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.web.aot.hint;
18+
19+
import org.springframework.aot.hint.RuntimeHints;
20+
import org.springframework.aot.hint.RuntimeHintsRegistrar;
21+
import org.springframework.jdbc.core.JdbcOperations;
22+
import org.springframework.security.web.webauthn.api.CredentialRecord;
23+
import org.springframework.security.web.webauthn.management.UserCredentialRepository;
24+
25+
/**
26+
*
27+
* A JDBC implementation of an {@link UserCredentialRepository} that uses a
28+
* {@link JdbcOperations} for {@link CredentialRecord} persistence.
29+
*
30+
* @author Max Batischev
31+
* @since 6.5
32+
*/
33+
class UserCredentialRuntimeHints implements RuntimeHintsRegistrar {
34+
35+
@Override
36+
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
37+
hints.resources().registerPattern("org/springframework/security/user-credentials-schema.sql");
38+
}
39+
40+
}
Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
/*
2+
* Copyright 2002-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.web.webauthn.management;
18+
19+
import java.sql.PreparedStatement;
20+
import java.sql.ResultSet;
21+
import java.sql.SQLException;
22+
import java.sql.Timestamp;
23+
import java.sql.Types;
24+
import java.time.Instant;
25+
import java.util.ArrayList;
26+
import java.util.HashSet;
27+
import java.util.List;
28+
import java.util.Set;
29+
import java.util.function.Function;
30+
31+
import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
32+
import org.springframework.jdbc.core.JdbcOperations;
33+
import org.springframework.jdbc.core.PreparedStatementSetter;
34+
import org.springframework.jdbc.core.RowMapper;
35+
import org.springframework.jdbc.core.SqlParameterValue;
36+
import org.springframework.jdbc.support.lob.DefaultLobHandler;
37+
import org.springframework.jdbc.support.lob.LobCreator;
38+
import org.springframework.jdbc.support.lob.LobHandler;
39+
import org.springframework.security.web.webauthn.api.AuthenticatorTransport;
40+
import org.springframework.security.web.webauthn.api.Bytes;
41+
import org.springframework.security.web.webauthn.api.CredentialRecord;
42+
import org.springframework.security.web.webauthn.api.ImmutableCredentialRecord;
43+
import org.springframework.security.web.webauthn.api.ImmutablePublicKeyCose;
44+
import org.springframework.security.web.webauthn.api.PublicKeyCredentialType;
45+
import org.springframework.util.Assert;
46+
import org.springframework.util.CollectionUtils;
47+
48+
/**
49+
* A JDBC implementation of an {@link UserCredentialRepository} that uses a
50+
* {@link JdbcOperations} for {@link CredentialRecord} persistence.
51+
*
52+
* <b>NOTE:</b> This {@code UserCredentialRepository} depends on the table definition
53+
* described in "classpath:org/springframework/security/user-credentials-schema.sql" and
54+
* therefore MUST be defined in the database schema.
55+
*
56+
* @author Max Batischev
57+
* @since 6.5
58+
* @see UserCredentialRepository
59+
* @see CredentialRecord
60+
* @see JdbcOperations
61+
* @see RowMapper
62+
*/
63+
public final class JdbcUserCredentialRepository implements UserCredentialRepository {
64+
65+
private RowMapper<CredentialRecord> credentialRecordRowMapper = new CredentialRecordRowMapper();
66+
67+
private Function<CredentialRecord, List<SqlParameterValue>> credentialRecordParametersMapper = new CredentialRecordParametersMapper();
68+
69+
private LobHandler lobHandler = new DefaultLobHandler();
70+
71+
private final JdbcOperations jdbcOperations;
72+
73+
private static final String TABLE_NAME = "user_credentials";
74+
75+
// @formatter:off
76+
private static final String COLUMN_NAMES = "credential_id, "
77+
+ "user_entity_user_id, "
78+
+ "public_key, "
79+
+ "signature_count, "
80+
+ "uv_initialized, "
81+
+ "backup_eligible, "
82+
+ "authenticator_transports, "
83+
+ "public_key_credential_type, "
84+
+ "backup_state, "
85+
+ "attestation_object, "
86+
+ "attestation_client_data_json, "
87+
+ "created, "
88+
+ "last_used, "
89+
+ "label ";
90+
// @formatter:on
91+
92+
// @formatter:off
93+
private static final String AUTHENTICATOR_TRANSPORT_COLUMN_NAMES = "value, "
94+
+ "credential_id ";
95+
// @formatter:on
96+
97+
// @formatter:off
98+
private static final String SAVE_CREDENTIAL_RECORD_SQL = "INSERT INTO " + TABLE_NAME
99+
+ " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
100+
// @formatter:on
101+
102+
private static final String ID_FILTER = "credential_id = ? ";
103+
104+
private static final String USER_ID_FILTER = "user_entity_user_id = ? ";
105+
106+
// @formatter:off
107+
private static final String FIND_CREDENTIAL_RECORD_BY_ID_SQL = "SELECT " + COLUMN_NAMES
108+
+ " FROM " + TABLE_NAME
109+
+ " WHERE " + ID_FILTER;
110+
// @formatter:on
111+
112+
// @formatter:off
113+
private static final String FIND_CREDENTIAL_RECORD_BY_USER_ID_SQL = "SELECT " + COLUMN_NAMES
114+
+ " FROM " + TABLE_NAME
115+
+ " WHERE " + USER_ID_FILTER;
116+
// @formatter:on
117+
118+
private static final String DELETE_CREDENTIAL_RECORD_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + ID_FILTER;
119+
120+
/**
121+
* Constructs a {@code JdbcUserCredentialRepository} using the provided parameters.
122+
* @param jdbcOperations the JDBC operations
123+
*/
124+
public JdbcUserCredentialRepository(JdbcOperations jdbcOperations) {
125+
Assert.notNull(jdbcOperations, "jdbcOperations cannot be null");
126+
this.jdbcOperations = jdbcOperations;
127+
}
128+
129+
@Override
130+
public void delete(Bytes credentialId) {
131+
Assert.notNull(credentialId, "credentialId cannot be null");
132+
SqlParameterValue[] parameters = new SqlParameterValue[] {
133+
new SqlParameterValue(Types.VARCHAR, credentialId.toBase64UrlString()), };
134+
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
135+
this.jdbcOperations.update(DELETE_CREDENTIAL_RECORD_SQL, pss);
136+
}
137+
138+
@Override
139+
public void save(CredentialRecord record) {
140+
Assert.notNull(record, "record cannot be null");
141+
List<SqlParameterValue> parameters = this.credentialRecordParametersMapper.apply(record);
142+
try (LobCreator lobCreator = this.lobHandler.getLobCreator()) {
143+
PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter(lobCreator,
144+
parameters.toArray());
145+
this.jdbcOperations.update(SAVE_CREDENTIAL_RECORD_SQL, pss);
146+
}
147+
}
148+
149+
@Override
150+
public CredentialRecord findByCredentialId(Bytes credentialId) {
151+
Assert.notNull(credentialId, "credentialId cannot be null");
152+
SqlParameterValue[] parameters = new SqlParameterValue[] {
153+
new SqlParameterValue(Types.VARCHAR, credentialId.toBase64UrlString()) };
154+
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
155+
List<CredentialRecord> result = this.jdbcOperations.query(FIND_CREDENTIAL_RECORD_BY_ID_SQL, pss,
156+
this.credentialRecordRowMapper);
157+
return !result.isEmpty() ? result.get(0) : null;
158+
}
159+
160+
@Override
161+
public List<CredentialRecord> findByUserId(Bytes userId) {
162+
Assert.notNull(userId, "userId cannot be null");
163+
SqlParameterValue[] parameters = new SqlParameterValue[] {
164+
new SqlParameterValue(Types.VARCHAR, userId.toBase64UrlString()) };
165+
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
166+
return this.jdbcOperations.query(FIND_CREDENTIAL_RECORD_BY_USER_ID_SQL, pss, this.credentialRecordRowMapper);
167+
}
168+
169+
/**
170+
* Sets a {@link LobHandler} for large binary fields and large text field parameters.
171+
* @param lobHandler the lob handler
172+
*/
173+
public void setLobHandler(LobHandler lobHandler) {
174+
Assert.notNull(lobHandler, "lobHandler cannot be null");
175+
this.lobHandler = lobHandler;
176+
}
177+
178+
private static class CredentialRecordParametersMapper
179+
implements Function<CredentialRecord, List<SqlParameterValue>> {
180+
181+
@Override
182+
public List<SqlParameterValue> apply(CredentialRecord record) {
183+
List<SqlParameterValue> parameters = new ArrayList<>();
184+
185+
List<String> transports = new ArrayList<>();
186+
if (!CollectionUtils.isEmpty(record.getTransports())) {
187+
for (AuthenticatorTransport transport : record.getTransports()) {
188+
transports.add(transport.getValue());
189+
}
190+
}
191+
192+
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getCredentialId().toBase64UrlString()));
193+
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getUserEntityUserId().toBase64UrlString()));
194+
parameters.add(new SqlParameterValue(Types.BLOB, record.getPublicKey().getBytes()));
195+
parameters.add(new SqlParameterValue(Types.BIGINT, record.getSignatureCount()));
196+
parameters.add(new SqlParameterValue(Types.BOOLEAN, record.isUvInitialized()));
197+
parameters.add(new SqlParameterValue(Types.BOOLEAN, record.isBackupEligible()));
198+
parameters.add(new SqlParameterValue(Types.VARCHAR,
199+
(!CollectionUtils.isEmpty(record.getTransports())) ? String.join(",", transports) : ""));
200+
parameters.add(new SqlParameterValue(Types.VARCHAR,
201+
(record.getCredentialType() != null) ? record.getCredentialType().getValue() : null));
202+
parameters.add(new SqlParameterValue(Types.BOOLEAN, record.isBackupState()));
203+
parameters.add(new SqlParameterValue(Types.BLOB,
204+
(record.getAttestationObject() != null) ? record.getAttestationObject().getBytes() : null));
205+
parameters.add(new SqlParameterValue(Types.BLOB, (record.getAttestationClientDataJSON() != null)
206+
? record.getAttestationClientDataJSON().getBytes() : null));
207+
parameters.add(new SqlParameterValue(Types.TIMESTAMP, fromInstant(record.getCreated())));
208+
parameters.add(new SqlParameterValue(Types.TIMESTAMP, fromInstant(record.getLastUsed())));
209+
parameters.add(new SqlParameterValue(Types.VARCHAR, record.getLabel()));
210+
211+
return parameters;
212+
}
213+
214+
private Timestamp fromInstant(Instant instant) {
215+
if (instant == null) {
216+
return null;
217+
}
218+
return Timestamp.from(instant);
219+
}
220+
221+
}
222+
223+
private static final class LobCreatorArgumentPreparedStatementSetter extends ArgumentPreparedStatementSetter {
224+
225+
private final LobCreator lobCreator;
226+
227+
private LobCreatorArgumentPreparedStatementSetter(LobCreator lobCreator, Object[] args) {
228+
super(args);
229+
this.lobCreator = lobCreator;
230+
}
231+
232+
@Override
233+
protected void doSetValue(PreparedStatement ps, int parameterPosition, Object argValue) throws SQLException {
234+
if (argValue instanceof SqlParameterValue paramValue) {
235+
if (paramValue.getSqlType() == Types.BLOB) {
236+
if (paramValue.getValue() != null) {
237+
Assert.isInstanceOf(byte[].class, paramValue.getValue(),
238+
"Value of blob parameter must be byte[]");
239+
}
240+
byte[] valueBytes = (byte[]) paramValue.getValue();
241+
this.lobCreator.setBlobAsBytes(ps, parameterPosition, valueBytes);
242+
return;
243+
}
244+
}
245+
super.doSetValue(ps, parameterPosition, argValue);
246+
}
247+
248+
}
249+
250+
private static class CredentialRecordRowMapper implements RowMapper<CredentialRecord> {
251+
252+
private LobHandler lobHandler = new DefaultLobHandler();
253+
254+
@Override
255+
public CredentialRecord mapRow(ResultSet rs, int rowNum) throws SQLException {
256+
Bytes credentialId = Bytes.fromBase64(new String(rs.getString("credential_id").getBytes()));
257+
Bytes userEntityUserId = Bytes.fromBase64(new String(rs.getString("user_entity_user_id").getBytes()));
258+
ImmutablePublicKeyCose publicKey = new ImmutablePublicKeyCose(
259+
this.lobHandler.getBlobAsBytes(rs, "public_key"));
260+
long signatureCount = rs.getLong("signature_count");
261+
boolean uvInitialized = rs.getBoolean("uv_initialized");
262+
boolean backupEligible = rs.getBoolean("backup_eligible");
263+
PublicKeyCredentialType credentialType = PublicKeyCredentialType
264+
.valueOf(rs.getString("public_key_credential_type"));
265+
boolean backupState = rs.getBoolean("backup_state");
266+
267+
Bytes attestationObject = null;
268+
byte[] rawAttestationObject = this.lobHandler.getBlobAsBytes(rs, "attestation_object");
269+
if (rawAttestationObject != null) {
270+
attestationObject = new Bytes(rawAttestationObject);
271+
}
272+
273+
Bytes attestationClientDataJson = null;
274+
byte[] rawAttestationClientDataJson = this.lobHandler.getBlobAsBytes(rs, "attestation_client_data_json");
275+
if (rawAttestationClientDataJson != null) {
276+
attestationClientDataJson = new Bytes(rawAttestationClientDataJson);
277+
}
278+
279+
Instant created = fromTimestamp(rs.getTimestamp("created"));
280+
Instant lastUsed = fromTimestamp(rs.getTimestamp("last_used"));
281+
String label = rs.getString("label");
282+
String[] transports = rs.getString("authenticator_transports").split(",");
283+
284+
Set<AuthenticatorTransport> authenticatorTransports = new HashSet<>();
285+
for (String transport : transports) {
286+
authenticatorTransports.add(AuthenticatorTransport.valueOf(transport));
287+
}
288+
return ImmutableCredentialRecord.builder()
289+
.credentialId(credentialId)
290+
.userEntityUserId(userEntityUserId)
291+
.publicKey(publicKey)
292+
.signatureCount(signatureCount)
293+
.uvInitialized(uvInitialized)
294+
.backupEligible(backupEligible)
295+
.credentialType(credentialType)
296+
.backupState(backupState)
297+
.attestationObject(attestationObject)
298+
.attestationClientDataJSON(attestationClientDataJson)
299+
.created(created)
300+
.label(label)
301+
.lastUsed(lastUsed)
302+
.transports(authenticatorTransports)
303+
.build();
304+
}
305+
306+
private Instant fromTimestamp(Timestamp timestamp) {
307+
if (timestamp == null) {
308+
return null;
309+
}
310+
return timestamp.toInstant();
311+
}
312+
313+
}
314+
315+
}
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
org.springframework.aot.hint.RuntimeHintsRegistrar=\
2-
org.springframework.security.web.aot.hint.WebMvcSecurityRuntimeHints
2+
org.springframework.security.web.aot.hint.WebMvcSecurityRuntimeHints,\
3+
org.springframework.security.web.aot.hint.UserCredentialRuntimeHints
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
create table user_credentials
2+
(
3+
credential_id varchar(1000) not null,
4+
user_entity_user_id varchar(1000) not null,
5+
public_key blob not null,
6+
signature_count bigint,
7+
uv_initialized boolean,
8+
backup_eligible boolean not null,
9+
authenticator_transports varchar(1000),
10+
public_key_credential_type varchar(100),
11+
backup_state boolean not null,
12+
attestation_object blob,
13+
attestation_client_data_json blob,
14+
created timestamp,
15+
last_used timestamp,
16+
label varchar(1000) not null,
17+
primary key (credential_id)
18+
);

0 commit comments

Comments
 (0)