diff --git a/lib/saml_idp/controller.rb b/lib/saml_idp/controller.rb
index 6745115..7adb4ab 100644
--- a/lib/saml_idp/controller.rb
+++ b/lib/saml_idp/controller.rb
@@ -1,4 +1,5 @@
# encoding: utf-8
+require 'logger'
module SamlIdp
module Controller
@@ -54,18 +55,25 @@ def decode_SAMLRequest(saml_request)
@saml_request = zstream.inflate(Base64.decode64(saml_request))
zstream.finish
zstream.close
- @saml_request_id = @saml_request[/ID=['"](.+?)['"]/, 1]
- @saml_acs_url = @saml_request[/AssertionConsumerServiceURL=['"](.+?)['"]/, 1]
+ xml_doc = Nokogiri::XML(@saml_request)do |config|
+ # Strict parsing; raise an error when parsing malformed documents
+ config.strict
+ end
+ auth_request = xml_doc.xpath('//samlp:AuthnRequest').first
+ @saml_request_id = auth_request['ID']
+ @saml_acs_url = auth_request['AssertionConsumerServiceURL']
end
def encode_SAMLResponse(nameID, opts = {})
now = Time.now.utc
+ encoded_saml_acs_url = @saml_acs_url.encode(xml: :attr)
+
response_id, reference_id = SecureRandom.uuid, SecureRandom.uuid
audience_uri = opts[:audience_uri] || saml_acs_url[/^(.*?\/\/.*?\/)/, 1]
issuer_uri = opts[:issuer_uri] || (defined?(request) && request.url) || "http://example.com"
attributes_statement = attributes(opts[:attributes_provider], nameID)
- assertion = %[#{issuer_uri}#{nameID}#{audience_uri}#{attributes_statement}urn:federation:authentication:windows]
+ assertion = %[#{issuer_uri}#{nameID}#{audience_uri}#{attributes_statement}urn:federation:authentication:windows]
digest_value = Base64.encode64(algorithm.digest(assertion)).gsub(/\n/, '')
@@ -77,7 +85,7 @@ def encode_SAMLResponse(nameID, opts = {})
assertion_and_signature = assertion.sub(/Issuer\>\#{signature}#{issuer_uri}#{assertion_and_signature}]
+ xml = %[#{issuer_uri}#{assertion_and_signature}]
Base64.encode64(xml)
end
diff --git a/spec/saml_idp/controller_spec.rb b/spec/saml_idp/controller_spec.rb
index cdb03da..136e0c8 100644
--- a/spec/saml_idp/controller_spec.rb
+++ b/spec/saml_idp/controller_spec.rb
@@ -7,12 +7,41 @@
def params
@params ||= {}
end
+ SAML_ACS_URLS = %w(https://example.com/saml/consume https://example.com/saml/consume?toto=value&tata=value2)
- it "should find the SAML ACS URL" do
- requested_saml_acs_url = "https://example.com/saml/consume"
- params[:SAMLRequest] = make_saml_request(requested_saml_acs_url)
+ SAML_ACS_URLS.each do |requested_saml_acs_url|
+ it "should find the SAML ACS URL: #{requested_saml_acs_url}" do
+ params[:SAMLRequest] = make_saml_request(requested_saml_acs_url)
+ validate_saml_request
+ expect(saml_acs_url).to eq(requested_saml_acs_url)
+ end
+ end
+
+ it 'should find the SAML ACS URL' do
+ xml = %q(
+
+
+
+ )
+ params[:SAMLRequest] = prepare_saml_request(xml)
validate_saml_request
- expect(saml_acs_url).to eq(requested_saml_acs_url)
+ expect(saml_acs_url).to eq('https://sp.example.com/SAML2/SSO/Artifact')
+ end
+
+ it 'does not validate wrong requests' do
+ params[:SAMLRequest] = 'FAKE NEWS'
+ expect{validate_saml_request}.to raise_error
+ end
+
+ it 'does not validate wrong xmls' do
+ xml = %q(
+
+
+
+ )
+
+ params[:SAMLRequest] = prepare_saml_request(xml)
+ expect{validate_saml_request}.to raise_error
end
context "SAML Responses" do
@@ -54,4 +83,20 @@ def params
end
end
end
+ context "SAML Responses with special characters" do
+ before(:each) do
+ params[:SAMLRequest] = make_saml_request('https://example.com/saml/consume?toto=value&tata=value2')
+ validate_saml_request
+ end
+ it "should create a SAML Response" do
+ saml_response = encode_SAMLResponse("foo@example.com")
+ response = OneLogin::RubySaml::Response.new(saml_response)
+ expect(response.name_id).to eq("foo@example.com")
+ expect(response.issuer).to eq("http://example.com")
+ response.settings = saml_settings
+ expect(response.is_valid?).to be true
+ end
+ end
+
+
end
diff --git a/spec/support/saml_request_macros.rb b/spec/support/saml_request_macros.rb
index fc2a6da..f8c9d87 100644
--- a/spec/support/saml_request_macros.rb
+++ b/spec/support/saml_request_macros.rb
@@ -16,4 +16,17 @@ def saml_settings(options = {})
settings
end
+ def prepare_saml_request(xml)
+ deflated = deflate(xml)
+ encode(deflated)
+ end
+
+ def encode(encoded)
+ Base64.encode64(encoded).gsub(/\n/, "")
+ end
+
+ def deflate(inflated)
+ Zlib::Deflate.deflate(inflated, 9)[2..-5]
+ end
+
end