Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ dependencies {
testImplementation group: 'org.scalamock', name: 'scalamock_3', version: '6.0.0'
testImplementation group: 'org.scalatestplus', name: 'scalacheck-1-16_3', version: '3.2.14.0'
testImplementation group: 'org.scalatest', name: 'scalatest_3', version: '3.2.16'
testImplementation group: 'com.dimafeng', name: 'testcontainers-scala-core_3', version: '0.41.4'
testImplementation group: 'com.dimafeng', name: 'testcontainers-scala-core_3', version: '0.44.0'

constraints {
api group: 'io.netty', name: 'netty-codec-http2', version: '4.1.126.Final'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ object ImpersonationWarning {
implicit val jwtAuthRule: ImpersonationWarningExtractor[JwtAuthRule] = ImpersonationWarningExtractor[JwtAuthRule] { (rule, blockName, _) =>
Some(impersonationNotSupportedWarning(rule, blockName))
}
implicit val jwtAuthenticationRule: ImpersonationWarningExtractor[JwtAuthenticationRule] = ImpersonationWarningExtractor[JwtAuthenticationRule] { (rule, blockName, _) =>
Some(impersonationNotSupportedWarning(rule, blockName))
}
implicit val jwtAuthorizationRule: ImpersonationWarningExtractor[JwtAuthorizationRule] = ImpersonationWarningExtractor[JwtAuthorizationRule] { (rule, blockName, _) =>
Some(impersonationNotSupportedWarning(rule, blockName))
}
implicit val ldapAuthenticationRule: ImpersonationWarningExtractor[LdapAuthenticationRule] = ImpersonationWarningExtractor[LdapAuthenticationRule] { (rule, blockName, requestId) =>
ldapWarning(rule.name, blockName, rule.settings.ldap.id, rule.impersonation)(requestId)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ object RuleOrdering {
// then we could check potentially slow async rules
classOf[LdapAuthRule],
classOf[LdapAuthenticationRule],
classOf[JwtAuthenticationRule],
classOf[RorKbnAuthenticationRule],
classOf[ExternalAuthenticationRule],
classOf[AnyOfGroupsRule],
Expand All @@ -74,6 +75,7 @@ object RuleOrdering {
classOf[CombinedLogicGroupsRule],
// all authorization rules should be placed after any authentication rule
classOf[LdapAuthorizationRule],
classOf[JwtAuthorizationRule],
classOf[RorKbnAuthorizationRule],
classOf[ExternalAuthorizationRule],
// Inspection rules next; these act based on properties of the request.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@ import tech.beshu.ror.accesscontrol.factory.decoders.definitions.Definitions.Ite

import java.security.PublicKey

final case class JwtDef(override val id: Name,
authorizationTokenDef: AuthorizationTokenDef,
checkMethod: SignatureCheckMethod,
userClaim: Option[Jwt.ClaimName],
groupsConfig: Option[GroupsConfig])
extends Item {

sealed trait JwtDef extends Item {
override type Id = Name
override val idShow: Show[Name] = Show.show(_.value.value)

override def id: Name

def authorizationTokenDef: AuthorizationTokenDef
def checkMethod: SignatureCheckMethod
}

object JwtDef {
final case class Name(value: NonEmptyString)

Expand All @@ -48,4 +48,31 @@ object JwtDef {
final case class GroupsConfig(idsClaim: Jwt.ClaimName, namesClaim: Option[Jwt.ClaimName])

implicit val nameEq: Eq[Name] = Eq.fromUniversalEquals
}
}

trait JwtDefForAuthentication extends JwtDef {
def userClaim: Jwt.ClaimName
}

trait JwtDefForAuthorization extends JwtDef {
def groupsConfig: GroupsConfig
}

trait JwtDefForAuth extends JwtDefForAuthentication with JwtDefForAuthorization


final case class AuthenticationJwtDef(override val id: Name,
authorizationTokenDef: AuthorizationTokenDef,
checkMethod: SignatureCheckMethod,
userClaim: Jwt.ClaimName) extends JwtDefForAuthentication

final case class AuthorizationJwtDef(override val id: Name,
authorizationTokenDef: AuthorizationTokenDef,
checkMethod: SignatureCheckMethod,
groupsConfig: GroupsConfig) extends JwtDefForAuthorization

final case class AuthJwtDef(override val id: Name,
authorizationTokenDef: AuthorizationTokenDef,
checkMethod: SignatureCheckMethod,
userClaim: Jwt.ClaimName,
groupsConfig: GroupsConfig) extends JwtDefForAuth
Original file line number Diff line number Diff line change
Expand Up @@ -16,246 +16,27 @@
*/
package tech.beshu.ror.accesscontrol.blocks.rules.auth

import io.jsonwebtoken.Jwts
import io.jsonwebtoken.security.Keys
import monix.eval.Task
import org.apache.logging.log4j.scala.Logging
import tech.beshu.ror.accesscontrol.blocks.definitions.JwtDef
import tech.beshu.ror.accesscontrol.blocks.definitions.JwtDef.SignatureCheckMethod.*
import tech.beshu.ror.accesscontrol.blocks.rules.Rule
import tech.beshu.ror.accesscontrol.blocks.rules.Rule.AuthenticationRule.EligibleUsersSupport
import tech.beshu.ror.accesscontrol.blocks.rules.Rule.RuleResult.{Fulfilled, Rejected}
import tech.beshu.ror.accesscontrol.blocks.rules.Rule.{AuthRule, RuleName, RuleResult}
import tech.beshu.ror.accesscontrol.blocks.rules.auth.JwtAuthRule.Groups
import tech.beshu.ror.accesscontrol.blocks.rules.auth.base.impersonation.{AuthenticationImpersonationCustomSupport, AuthorizationImpersonationCustomSupport}
import tech.beshu.ror.accesscontrol.blocks.{BlockContext, BlockContextUpdater}
import tech.beshu.ror.accesscontrol.domain.LoggedUser.DirectlyLoggedUser
import tech.beshu.ror.accesscontrol.blocks.rules.Rule.RuleName
import tech.beshu.ror.accesscontrol.blocks.rules.auth.base.BaseComposedAuthenticationAndAuthorizationRule
import tech.beshu.ror.accesscontrol.domain.*
import tech.beshu.ror.accesscontrol.request.RequestContext
import tech.beshu.ror.accesscontrol.request.RequestContextOps.*
import tech.beshu.ror.accesscontrol.utils.ClaimsOps.*
import tech.beshu.ror.accesscontrol.utils.ClaimsOps.ClaimSearchResult.{Found, NotFound}
import tech.beshu.ror.implicits.*
import tech.beshu.ror.utils.RefinedUtils.*
import tech.beshu.ror.utils.uniquelist.{UniqueList, UniqueNonEmptyList}

import scala.util.Try

final class JwtAuthRule(val settings: JwtAuthRule.Settings,
override val userIdCaseSensitivity: CaseSensitivity)
extends AuthRule
with AuthenticationImpersonationCustomSupport
with AuthorizationImpersonationCustomSupport
with Logging {
final class JwtAuthRule(val authentication: JwtAuthenticationRule,
val authorization: JwtAuthorizationRule)
extends BaseComposedAuthenticationAndAuthorizationRule(
authenticationRule = authentication.withDisabledCallsToExternalAuthenticationService,
authorizationRule = authorization
) {

override val name: Rule.Name = JwtAuthRule.Name.name

override val eligibleUsers: EligibleUsersSupport = EligibleUsersSupport.NotAvailable

private val parser =
settings.jwt.checkMethod match {
case NoCheck(_) => Jwts.parser().unsecured().build()
case Hmac(rawKey) => Jwts.parser().verifyWith(Keys.hmacShaKeyFor(rawKey)).build()
case Rsa(pubKey) => Jwts.parser().verifyWith(pubKey).build()
case Ec(pubKey) => Jwts.parser().verifyWith(pubKey).build()
}

override protected[rules] def authenticate[B <: BlockContext : BlockContextUpdater](blockContext: B): Task[RuleResult[B]] =
Task.now(RuleResult.Fulfilled(blockContext))

override protected[rules] def authorize[B <: BlockContext : BlockContextUpdater](blockContext: B): Task[RuleResult[B]] =
Task
.unit
.flatMap { _ =>
settings.permittedGroups match {
case Groups.NotDefined =>
authorizeUsingJwtToken(blockContext)
case Groups.Defined(groupsLogic) if blockContext.isCurrentGroupPotentiallyEligible(groupsLogic) =>
authorizeUsingJwtToken(blockContext)
case Groups.Defined(_) =>
Task.now(RuleResult.Rejected())
}
}

private def authorizeUsingJwtToken[B <: BlockContext : BlockContextUpdater](blockContext: B): Task[RuleResult[B]] = {
jwtTokenFrom(blockContext.requestContext) match {
case None =>
logger.debug(s"[${blockContext.requestContext.id.show}] Authorization header '${settings.jwt.authorizationTokenDef.headerName.show}' is missing or does not contain a JWT token")
Task.now(Rejected())
case Some(token) =>
process(token, blockContext)
}
}

private def jwtTokenFrom(requestContext: RequestContext) = {
requestContext
.authorizationToken(settings.jwt.authorizationTokenDef)
.map(t => Jwt.Token(t.value))
}

private def process[B <: BlockContext : BlockContextUpdater](token: Jwt.Token,
blockContext: B): Task[RuleResult[B]] = {
implicit val requestId: RequestId = blockContext.requestContext.id.toRequestId
userAndGroupsFromJwtToken(token) match {
case Left(_) =>
Task.now(Rejected())
case Right((tokenPayload, user, groups)) =>
if (logger.delegate.isDebugEnabled) {
logClaimSearchResults(user, groups)(blockContext.requestContext.id.toRequestId)
}
val claimProcessingResult = for {
newBlockContext <- handleUserClaimSearchResult(blockContext, user)
finalBlockContext <- handleGroupsClaimSearchResult(newBlockContext, groups)
} yield finalBlockContext.withUserMetadata(_.withJwtToken(tokenPayload))
claimProcessingResult match {
case Left(_) =>
Task.now(Rejected())
case Right(modifiedBlockContext) =>
settings.jwt.checkMethod match {
case NoCheck(service) =>
implicit val requestId: RequestId = blockContext.requestContext.id.toRequestId
service
.authenticate(Credentials(User.Id(nes("jwt")), PlainTextSecret(token.value)))
.map(RuleResult.resultBasedOnCondition(modifiedBlockContext)(_))
case Hmac(_) | Rsa(_) | Ec(_) =>
Task.now(Fulfilled(modifiedBlockContext))
}
}
}
}

private def logClaimSearchResults(user: Option[ClaimSearchResult[User.Id]],
groups: Option[ClaimSearchResult[UniqueList[Group]]])
(implicit requestId: RequestId): Unit = {
(settings.jwt.userClaim, user) match {
case (Some(userClaim), Some(u)) =>
logger.debug(s"[${requestId.show}] JWT resolved user for claim ${userClaim.name.rawPath}: ${u.show}")
case _ =>
}
(settings.jwt.groupsConfig, groups) match {
case (Some(groupsConfig), Some(g)) =>
val claimsDescription = groupsConfig.namesClaim match {
case Some(namesClaim) => s"claims (id:'${groupsConfig.idsClaim.name.show}',name:'${namesClaim.name.show}')"
case None => s"claim '${groupsConfig.idsClaim.name.show}'"
}
logger.debug(s"[${requestId.show}] JWT resolved groups for ${claimsDescription.show}: ${g.show}")
case _ =>
}
}

private def userAndGroupsFromJwtToken(token: Jwt.Token)
(implicit requestId: RequestId) = {
claimsFrom(token).map { decodedJwtToken =>
(decodedJwtToken, userIdFrom(decodedJwtToken), groupsFrom(decodedJwtToken))
}
}

private def logBadToken(ex: Throwable, token: Jwt.Token)
(implicit requestId: RequestId): Unit = {
val tokenParts = token.show.split("\\.")
val printableToken = if (!logger.delegate.isDebugEnabled && tokenParts.length === 3) {
// signed JWT, last block is the cryptographic digest, which should be treated as a secret.
s"${tokenParts(0)}.${tokenParts(1)} (omitted digest)"
}
else {
token.show
}
logger.debug(s"[${requestId.show}] JWT token '${printableToken.show}' parsing error: ${ex.getClass.getSimpleName.show} ${ex.getMessage.show}")
}

private def claimsFrom(token: Jwt.Token)
(implicit requestId: RequestId) = {
settings.jwt.checkMethod match {
case NoCheck(_) =>
token.value.value.split("\\.").toList match {
case fst :: snd :: _ =>
Try(parser.parseUnsecuredClaims(s"$fst.$snd.").getPayload)
.toEither
.map(Jwt.Payload.apply)
.left.map { ex => logBadToken(ex, token) }
case _ =>
Left(())
}
case Hmac(_) | Rsa(_) | Ec(_) =>
Try(parser.parseSignedClaims(token.value.value).getPayload)
.toEither
.map(Jwt.Payload.apply)
.left.map { ex => logBadToken(ex, token) }
}
}

private def userIdFrom(payload: Jwt.Payload) = {
settings.jwt.userClaim.map(payload.claims.userIdClaim)
}

private def groupsFrom(payload: Jwt.Payload) = {
settings.jwt.groupsConfig.map(groupsConfig =>
payload.claims.groupsClaim(groupsConfig.idsClaim, groupsConfig.namesClaim)
)
}

private def handleUserClaimSearchResult[B <: BlockContext : BlockContextUpdater](blockContext: B,
result: Option[ClaimSearchResult[User.Id]]) = {
result match {
case None => Right(blockContext)
case Some(Found(userId)) => Right(blockContext.withUserMetadata(_.withLoggedUser(DirectlyLoggedUser(userId))))
case Some(NotFound) => Left(())
}
}

private def handleGroupsClaimSearchResult[B <: BlockContext : BlockContextUpdater](blockContext: B,
result: Option[ClaimSearchResult[UniqueList[Group]]]) = {
(result, settings.permittedGroups) match {
case (None, Groups.Defined(_)) =>
Left(())
case (None, Groups.NotDefined) =>
Right(blockContext)
case (Some(NotFound), Groups.Defined(_)) =>
Left(())
case (Some(NotFound), Groups.NotDefined) =>
Right(blockContext) // if groups field is not found, we treat this situation as same as empty groups would be passed
case (Some(Found(groups)), Groups.Defined(groupsLogic)) =>
UniqueNonEmptyList.from(groups) match {
case Some(nonEmptyGroups) =>
groupsLogic.availableGroupsFrom(nonEmptyGroups) match {
case Some(matchedGroups) =>
checkIfCanContinueWithGroups(blockContext, UniqueList.from(matchedGroups))
.map(_.withUserMetadata(_.addAvailableGroups(matchedGroups)))
case None =>
Left(())
}
case None =>
Left(())
}
case (Some(Found(groups)), Groups.NotDefined) =>
checkIfCanContinueWithGroups(blockContext, groups)
}
}

private def checkIfCanContinueWithGroups[B <: BlockContext](blockContext: B,
groups: UniqueList[Group]) = {
UniqueNonEmptyList.from(groups.toList.map(_.id)) match {
case Some(nonEmptyGroups) if blockContext.isCurrentGroupEligible(GroupIds(nonEmptyGroups)) =>
Right(blockContext)
case Some(_) | None =>
Left(())
}
}
override val userIdCaseSensitivity: CaseSensitivity = authentication.userIdCaseSensitivity
}

object JwtAuthRule {

implicit case object Name extends RuleName[JwtAuthRule] {
override val name = Rule.Name("jwt_auth")
}

final case class Settings(jwt: JwtDef, permittedGroups: Groups)

sealed trait Groups

object Groups {
case object NotDefined extends Groups

final case class Defined(groupsLogic: GroupsLogic) extends Groups
}
}
Loading