Skip to content

Getting response attributes from Saml2AuthenticatedPrincipal #8667

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,11 +17,13 @@

import java.security.cert.X509Certificate;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
Expand All @@ -31,7 +33,19 @@
import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.joda.time.DateTime;
import org.opensaml.core.criterion.EntityIdCriterion;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.core.xml.io.Marshaller;

import org.opensaml.core.xml.schema.XSAny;
import org.opensaml.core.xml.schema.XSBoolean;
import org.opensaml.core.xml.schema.XSBooleanValue;
import org.opensaml.core.xml.schema.XSDateTime;
import org.opensaml.core.xml.schema.XSInteger;
import org.opensaml.core.xml.schema.XSString;
import org.opensaml.core.xml.schema.XSURI;
import org.opensaml.saml.common.assertion.ValidationContext;
import org.opensaml.saml.common.assertion.ValidationResult;
import org.opensaml.saml.common.xml.SAMLConstants;
Expand All @@ -45,6 +59,8 @@
import org.opensaml.saml.saml2.assertion.impl.AudienceRestrictionConditionValidator;
import org.opensaml.saml.saml2.assertion.impl.BearerSubjectConfirmationValidator;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.Attribute;
import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.EncryptedAssertion;
import org.opensaml.saml.saml2.core.EncryptedID;
import org.opensaml.saml.saml2.core.NameID;
Expand Down Expand Up @@ -205,8 +221,9 @@ public Authentication authenticate(Authentication authentication) throws Authent
List<Assertion> validAssertions = validateResponse(token, response);
Assertion assertion = validAssertions.get(0);
String username = getUsername(token, assertion);
Map<String, List<Object>> attributes = getAssertionAttributes(assertion);
return new Saml2Authentication(
new SimpleSaml2AuthenticatedPrincipal(username), token.getSaml2Response(),
new SimpleSaml2AuthenticatedPrincipal(username, attributes), token.getSaml2Response(),
this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion)));
} catch (Saml2AuthenticationException e) {
throw e;
Expand Down Expand Up @@ -494,6 +511,60 @@ private NameID decrypt(Saml2AuthenticationToken token, EncryptedID assertion)
throw last;
}

private Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {
Map<String, List<Object>> attributeMap = new LinkedHashMap<>();
for (AttributeStatement attributeStatement : assertion.getAttributeStatements()) {
for (Attribute attribute : attributeStatement.getAttributes()) {

List<Object> attributeValues = new ArrayList<>();
for (XMLObject xmlObject : attribute.getAttributeValues()) {
Object attributeValue = getXmlObjectValue(xmlObject);
if (attributeValue != null) {
attributeValues.add(attributeValue);
}
}
attributeMap.put(attribute.getName(), attributeValues);

}
}
return attributeMap;
}

private Object getXmlObjectValue(XMLObject xmlObject) {
if (xmlObject == null) {
return null;
}
if (xmlObject instanceof XSAny) {
return getXSAnyObjectValue((XSAny) xmlObject);
}
if (xmlObject instanceof XSString) {
return ((XSString) xmlObject).getValue();
}
if (xmlObject instanceof XSInteger) {
return ((XSInteger) xmlObject).getValue();
}
if (xmlObject instanceof XSURI) {
return ((XSURI) xmlObject).getValue();
}
if (xmlObject instanceof XSBoolean) {
XSBooleanValue xsBooleanValue = ((XSBoolean) xmlObject).getValue();
return xsBooleanValue != null ? xsBooleanValue.getValue() : null;
}
if (xmlObject instanceof XSDateTime) {
DateTime dateTime = ((XSDateTime) xmlObject).getValue();
return dateTime != null ? Instant.ofEpochMilli(dateTime.getMillis()) : null;
}
return null;
}

private Object getXSAnyObjectValue(XSAny xsAny) {
Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(xsAny);
if (marshaller != null) {
return this.saml.serialize(xsAny);
}
return xsAny.getTextContent();
}

private Saml2Error validationError(String code, String description) {
return new Saml2Error(code, description);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,7 +16,13 @@

package org.springframework.security.saml2.provider.service.authentication;

import org.springframework.lang.Nullable;
import org.springframework.security.core.AuthenticatedPrincipal;
import org.springframework.util.CollectionUtils;

import java.util.Collections;
import java.util.List;
import java.util.Map;

/**
* Saml2 representation of an {@link AuthenticatedPrincipal}.
Expand All @@ -25,4 +31,40 @@
* @since 5.2.2
*/
public interface Saml2AuthenticatedPrincipal extends AuthenticatedPrincipal {
/**
* Get the first value of Saml2 token attribute by name
*
* @param name the name of the attribute
* @param <A> the type of the attribute
* @return the first attribute value or {@code null} otherwise
* @since 5.4
*/
@Nullable
default <A> A getFirstAttribute(String name) {
List<A> values = getAttribute(name);
return CollectionUtils.firstElement(values);
}

/**
* Get the Saml2 token attribute by name
*
* @param name the name of the attribute
* @param <A> the type of the attribute
* @return the attribute or {@code null} otherwise
* @since 5.4
*/
@Nullable
default <A> List<A> getAttribute(String name) {
return (List<A>) getAttributes().get(name);
}

/**
* Get the Saml2 token attributes
*
* @return the Saml2 token attributes
* @since 5.4
*/
default Map<String, List<Object>> getAttributes() {
return Collections.emptyMap();
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,8 @@
package org.springframework.security.saml2.provider.service.authentication;

import java.io.Serializable;
import java.util.List;
import java.util.Map;

/**
* Default implementation of a {@link Saml2AuthenticatedPrincipal}.
Expand All @@ -27,13 +29,20 @@
class SimpleSaml2AuthenticatedPrincipal implements Saml2AuthenticatedPrincipal, Serializable {

private final String name;
private final Map<String, List<Object>> attributes;

SimpleSaml2AuthenticatedPrincipal(String name) {
SimpleSaml2AuthenticatedPrincipal(String name, Map<String, List<Object>> attributes) {
this.name = name;
this.attributes = attributes;
}

@Override
public String getName() {
return this.name;
}

@Override
public Map<String, List<Object>> getAttributes() {
return this.attributes;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

import org.hamcrest.BaseMatcher;
import org.hamcrest.Description;
Expand All @@ -39,6 +43,7 @@
import org.springframework.security.saml2.credentials.Saml2X509Credential;

import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.assertion;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.attributeStatements;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.encrypted;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.response;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signed;
Expand All @@ -47,6 +52,7 @@
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyDecryptingCredential;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
import static org.springframework.test.util.AssertionErrors.assertEquals;
import static org.springframework.test.util.AssertionErrors.assertTrue;
import static org.springframework.util.StringUtils.hasText;

Expand Down Expand Up @@ -193,6 +199,30 @@ public void authenticateWhenAssertionContainsValidationAddressThenItSucceeds() t
this.provider.authenticate(token);
}

@Test
public void authenticateWhenAssertionContainsAttributesThenItSucceeds() {
Response response = response();
Assertion assertion = assertion();
attributeStatements().forEach(as -> assertion.getAttributeStatements().add(as));
signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
response.getAssertions().add(assertion);
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
Authentication authentication = this.provider.authenticate(token);
Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal();

Map<String, Object> attributes = new LinkedHashMap<>();
attributes.put("email", Arrays.asList("[email protected]", "[email protected]"));
attributes.put("name", Collections.singletonList("John Doe"));
attributes.put("age", Collections.singletonList(21));
attributes.put("website", Collections.singletonList("https://johndoe.com/"));
attributes.put("registered", Collections.singletonList(true));
Instant registeredDate = Instant.ofEpochMilli(DateTime.parse("1970-01-01T00:00:00Z").getMillis());
attributes.put("registeredDate", Collections.singletonList(registeredDate));

assertEquals("Values should be equal", "John Doe", principal.getFirstAttribute("name"));
assertTrue("Attributes should be equal", attributes.equals(principal.getAttributes()));
}

@Test
public void authenticateWhenEncryptedAssertionWithoutSignatureThenItFails() throws Exception {
this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_SIGNATURE));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,15 +16,58 @@

package org.springframework.security.saml2.provider.service.authentication;

import org.junit.Assert;
import org.joda.time.DateTime;
import org.junit.Test;

import java.time.Instant;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;

public class SimpleSaml2AuthenticatedPrincipalTests {

@Test
public void createSimpleSaml2AuthenticatedPrincipal() {
SimpleSaml2AuthenticatedPrincipal principal = new SimpleSaml2AuthenticatedPrincipal("user");
Map<String, List<Object>> attributes = new LinkedHashMap<>();
attributes.put("email", Arrays.asList("[email protected]", "[email protected]"));
SimpleSaml2AuthenticatedPrincipal principal = new SimpleSaml2AuthenticatedPrincipal("user", attributes);
assertThat(principal.getName()).isEqualTo("user");
assertThat(principal.getAttributes()).isEqualTo(attributes);
}

@Test
public void getFirstAttributeWhenStringValueThenReturnsValue() {
Map<String, List<Object>> attributes = new LinkedHashMap<>();
attributes.put("email", Arrays.asList("[email protected]", "[email protected]"));
SimpleSaml2AuthenticatedPrincipal principal = new SimpleSaml2AuthenticatedPrincipal("user", attributes);
assertThat(principal.<String>getFirstAttribute("email")).isEqualTo(attributes.get("email").get(0));
}

@Test
public void getAttributeWhenStringValuesThenReturnsValues() {
Map<String, List<Object>> attributes = new LinkedHashMap<>();
attributes.put("email", Arrays.asList("[email protected]", "[email protected]"));
SimpleSaml2AuthenticatedPrincipal principal = new SimpleSaml2AuthenticatedPrincipal("user", attributes);
assertThat(principal.<String>getAttribute("email")).isEqualTo(attributes.get("email"));
}

@Test
public void getAttributeWhenDistinctValuesThenReturnsValues() {
final Boolean registered = true;
final Instant registeredDate = Instant.ofEpochMilli(DateTime.parse("1970-01-01T00:00:00Z").getMillis());

Map<String, List<Object>> attributes = new LinkedHashMap<>();
attributes.put("registration", Arrays.asList(registered, registeredDate));

SimpleSaml2AuthenticatedPrincipal principal = new SimpleSaml2AuthenticatedPrincipal("user", attributes);

List<Object> registrationInfo = principal.getAttribute("registration");

Assert.assertEquals("user", principal.getName());
assertThat(registrationInfo).isNotNull();
assertThat((Boolean) registrationInfo.get(0)).isEqualTo(registered);
assertThat((Instant) registrationInfo.get(1)).isEqualTo(registeredDate);
}
}
Loading