diff --git a/lib/saml_idp/controller.rb b/lib/saml_idp/controller.rb index 6745115..7adb4ab 100644 --- a/lib/saml_idp/controller.rb +++ b/lib/saml_idp/controller.rb @@ -1,4 +1,5 @@ # encoding: utf-8 +require 'logger' module SamlIdp module Controller @@ -54,18 +55,25 @@ def decode_SAMLRequest(saml_request) @saml_request = zstream.inflate(Base64.decode64(saml_request)) zstream.finish zstream.close - @saml_request_id = @saml_request[/ID=['"](.+?)['"]/, 1] - @saml_acs_url = @saml_request[/AssertionConsumerServiceURL=['"](.+?)['"]/, 1] + xml_doc = Nokogiri::XML(@saml_request)do |config| + # Strict parsing; raise an error when parsing malformed documents + config.strict + end + auth_request = xml_doc.xpath('//samlp:AuthnRequest').first + @saml_request_id = auth_request['ID'] + @saml_acs_url = auth_request['AssertionConsumerServiceURL'] end def encode_SAMLResponse(nameID, opts = {}) now = Time.now.utc + encoded_saml_acs_url = @saml_acs_url.encode(xml: :attr) + response_id, reference_id = SecureRandom.uuid, SecureRandom.uuid audience_uri = opts[:audience_uri] || saml_acs_url[/^(.*?\/\/.*?\/)/, 1] issuer_uri = opts[:issuer_uri] || (defined?(request) && request.url) || "http://example.com" attributes_statement = attributes(opts[:attributes_provider], nameID) - assertion = %[#{issuer_uri}#{nameID}#{audience_uri}#{attributes_statement}urn:federation:authentication:windows] + assertion = %[#{issuer_uri}#{nameID}#{audience_uri}#{attributes_statement}urn:federation:authentication:windows] digest_value = Base64.encode64(algorithm.digest(assertion)).gsub(/\n/, '') @@ -77,7 +85,7 @@ def encode_SAMLResponse(nameID, opts = {}) assertion_and_signature = assertion.sub(/Issuer\>\#{signature}#{issuer_uri}#{assertion_and_signature}] + xml = %[#{issuer_uri}#{assertion_and_signature}] Base64.encode64(xml) end diff --git a/spec/saml_idp/controller_spec.rb b/spec/saml_idp/controller_spec.rb index cdb03da..136e0c8 100644 --- a/spec/saml_idp/controller_spec.rb +++ b/spec/saml_idp/controller_spec.rb @@ -7,12 +7,41 @@ def params @params ||= {} end + SAML_ACS_URLS = %w(https://example.com/saml/consume https://example.com/saml/consume?toto=value&tata=value2) - it "should find the SAML ACS URL" do - requested_saml_acs_url = "https://example.com/saml/consume" - params[:SAMLRequest] = make_saml_request(requested_saml_acs_url) + SAML_ACS_URLS.each do |requested_saml_acs_url| + it "should find the SAML ACS URL: #{requested_saml_acs_url}" do + params[:SAMLRequest] = make_saml_request(requested_saml_acs_url) + validate_saml_request + expect(saml_acs_url).to eq(requested_saml_acs_url) + end + end + + it 'should find the SAML ACS URL' do + xml = %q( + + + + ) + params[:SAMLRequest] = prepare_saml_request(xml) validate_saml_request - expect(saml_acs_url).to eq(requested_saml_acs_url) + expect(saml_acs_url).to eq('https://sp.example.com/SAML2/SSO/Artifact') + end + + it 'does not validate wrong requests' do + params[:SAMLRequest] = 'FAKE NEWS' + expect{validate_saml_request}.to raise_error + end + + it 'does not validate wrong xmls' do + xml = %q( + + + + ) + + params[:SAMLRequest] = prepare_saml_request(xml) + expect{validate_saml_request}.to raise_error end context "SAML Responses" do @@ -54,4 +83,20 @@ def params end end end + context "SAML Responses with special characters" do + before(:each) do + params[:SAMLRequest] = make_saml_request('https://example.com/saml/consume?toto=value&tata=value2') + validate_saml_request + end + it "should create a SAML Response" do + saml_response = encode_SAMLResponse("foo@example.com") + response = OneLogin::RubySaml::Response.new(saml_response) + expect(response.name_id).to eq("foo@example.com") + expect(response.issuer).to eq("http://example.com") + response.settings = saml_settings + expect(response.is_valid?).to be true + end + end + + end diff --git a/spec/support/saml_request_macros.rb b/spec/support/saml_request_macros.rb index fc2a6da..f8c9d87 100644 --- a/spec/support/saml_request_macros.rb +++ b/spec/support/saml_request_macros.rb @@ -16,4 +16,17 @@ def saml_settings(options = {}) settings end + def prepare_saml_request(xml) + deflated = deflate(xml) + encode(deflated) + end + + def encode(encoded) + Base64.encode64(encoded).gsub(/\n/, "") + end + + def deflate(inflated) + Zlib::Deflate.deflate(inflated, 9)[2..-5] + end + end