Skip to content

Commit 0bce957

Browse files
+ Bug fix
1 parent 0f3fa6b commit 0bce957

File tree

1 file changed

+90
-71
lines changed

1 file changed

+90
-71
lines changed

src/check.py

Lines changed: 90 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import discord
22
from discord import app_commands
33
from discord.ext import commands
4-
from discord import abc
5-
from config import Config
64

5+
from .config import Config
76
from typing import (
8-
NoReturn, Literal, Union, Optional,
7+
Any, NoReturn, Literal, Union, Optional,
98
List, Sequence, TypedDict
109
)
1110

@@ -21,7 +20,7 @@ class PermissionRequirement(TypedDict):
2120
:param value: The permission value
2221
"""
2322
type: Literal['wl', 'bl']
24-
query: Literal['in_channel', 'has_role', 'has_permission', 'is_developer', 'minimum_role']
23+
query: Literal['in_channel', 'in_guild', 'has_role', 'has_permission', 'is_developer', 'minimum_role']
2524
value: Optional[Union[str, int]]
2625

2726
class PermissionGate(TypedDict):
@@ -44,7 +43,7 @@ class HybridContext:
4443
author: Union[discord.User, discord.Member]
4544
is_developer: bool
4645
guild: Optional[discord.Guild]
47-
channel: Optional[Union[abc.GuildChannel, abc.PrivateChannel, discord.Thread]]
46+
channel: Any
4847

4948
def __init__(self, __ctx: Union[discord.Interaction, commands.Context]) -> None:
5049
"""
@@ -54,62 +53,72 @@ def __init__(self, __ctx: Union[discord.Interaction, commands.Context]) -> None:
5453
----------
5554
:param __ctx: The command context
5655
"""
57-
self.author = (isinstance(__ctx, discord.Interaction) and __ctx.user) or __ctx.author
56+
self.author = __ctx.user if isinstance(__ctx, discord.Interaction) else __ctx.author
5857
self.is_developer = (self.author.id == Config.DEVELOPER_USER_ID)
5958
self.guild = __ctx.guild
6059
self.channel = __ctx.channel
6160

6261

6362

64-
def _validate_group(ctx: HybridContext, group: Sequence[PermissionGate]) -> bool:
63+
def _validate_group(ctx: HybridContext, group: List[PermissionGate]) -> bool:
6564
required: List[bool] = []
6665
optional: List[bool] = []
6766

6867
for gate in group:
69-
passFlag = False
68+
passFlag: bool = False
7069

7170
match gate['requirement']['query']:
7271
case 'is_developer':
73-
passFlag = (gate['requirement']['type'] == 'wl') and ctx.is_developer
74-
break
72+
passFlag = ((gate['requirement']['type'] == 'wl') and ctx.is_developer)
7573

7674
case 'has_role':
77-
if ctx.guild:
75+
if ctx.guild and isinstance(ctx.author, discord.Member):
7876
passFlag = (gate['requirement']['type'] == 'wl') and any(
7977
gate['requirement']['value'] in [role.id, role.name]
8078
for role in ctx.author.roles
8179
)
8280

8381
case 'has_permission':
84-
if ctx.guild:
82+
if ctx.guild and isinstance(ctx.author, discord.Member):
8583
passFlag = (gate['requirement']['type'] == 'wl') and getattr(
8684
ctx.author.guild_permissions,
87-
gate['requirement']['value']
85+
str(gate['requirement']['value'])
8886
)
8987

9088
case 'in_channel':
91-
if ctx.guild:
89+
if ctx.guild and isinstance(ctx.author, discord.Member):
9290
passFlag = (gate['requirement']['type'] == 'wl') and (
9391
gate['requirement']['value'] == ctx.channel.id
9492
)
9593

94+
case 'in_guild':
95+
if ctx.guild and isinstance(ctx.author, discord.Member):
96+
passFlag = (gate['requirement']['type'] == 'wl') and (
97+
gate['requirement']['value'] in
98+
[ctx.guild.id, ctx.guild.name]
99+
)
100+
96101
case 'minimum_role':
97-
if ctx.guild:
98-
role = ctx.guild.get_role(gate['requirement']['value'])
99-
passFlag = (gate['requirement']['type'] == 'wl') and role and (
102+
if ctx.guild and gate['requirement']['value'] and isinstance(ctx.author, discord.Member):
103+
role = ctx.guild.get_role(int(gate['requirement']['value']))
104+
passFlag = bool((gate['requirement']['type'] == 'wl') and role and (
100105
ctx.author.top_role >= role
101-
)
106+
))
102107

103108
if gate['type'] == 'required':
104109
required.append(passFlag)
105110
else:
106111
optional.append(passFlag)
107112

108-
return bool(all(required) and any(optional))
113+
return (
114+
(((len(required) > 0) and all(required)) or True)
115+
and
116+
(((len(optional) > 0) and any(optional)) or True)
117+
)
109118

110119

111120

112-
def validate(__ctx: Union[discord.Interaction, commands.Context], *gates: Sequence[Union[PermissionGate, Sequence[PermissionGate]]]) -> bool:
121+
def validate(__ctx: Union[discord.Interaction, commands.Context], *gates: Union[PermissionGate, Sequence[PermissionGate]]) -> bool:
113122
"""
114123
Validates a members access to a command
115124
@@ -118,88 +127,98 @@ def validate(__ctx: Union[discord.Interaction, commands.Context], *gates: Sequen
118127
:param __ctx: Command Context [legacy and app command supported]
119128
:param gates: List of permission gates
120129
"""
121-
ctx = HybridContext(__ctx)
122-
123-
ungrouped: List[PermissionGate] = []
124-
grouped: List[List[PermissionGate]] = []
125-
126-
# Separate grouped and ungrouped
127-
for gate in gates:
128-
if isinstance(gate, PermissionGate):
129-
ungrouped.append(gate)
130-
else:
131-
grouped.append(gate)
132-
133-
# Validate
134-
joined: List[List[PermissionGate]] = [ungrouped, *grouped]
135-
for gate in joined:
136-
passFlag = _validate_group(ctx, gate)
137-
if not passFlag:
138-
return False
139-
140-
return True
141-
142-
143-
144-
class Protected():
130+
try:
131+
ctx = HybridContext(__ctx)
132+
grouped: List[List[PermissionGate]] = [[]]
133+
134+
# Separate grouped and ungrouped
135+
for gate in gates:
136+
if isinstance(gate, Sequence):
137+
grouped.append(list(gate))
138+
else:
139+
grouped[0].append(gate)
140+
141+
# Validate
142+
for gate in grouped:
143+
if len(gate) > 0:
144+
passFlag = _validate_group(ctx, gate)
145+
if not passFlag:
146+
return False
147+
148+
return True
149+
except Exception as e:
150+
raise commands.CheckFailure(f'{e}') from e
151+
152+
153+
class Protected:
154+
@staticmethod
145155
def app(*clauses: Union[PermissionGate, Sequence[PermissionGate]]):
146156
"""
147157
Protect app command usage
148158
149159
Unclaused permission gates are treated as one clause
150-
All clauses muts pass for the user to be allowed access to the command
160+
All clauses must pass for the user to be allowed access to the command
151161
152162
Parameters
153163
----------
154164
:param *: Permission gates or clausees of permission gates
155165
"""
156-
async def _run(interaction):
166+
async def predicate(interaction):
157167
return validate(interaction, *clauses)
158-
return app_commands.check(_run)
159-
168+
return app_commands.check(predicate)
169+
170+
@staticmethod
160171
def legacy(*clauses: Union[PermissionGate, Sequence[PermissionGate]]):
161172
"""
162173
Protect legacy command usage
163174
164175
Unclaused permission gates are treated as one clause
165-
All clauses muts pass for the user to be allowed access to the command
176+
All clauses must pass for the user to be allowed access to the command
166177
167178
Parameters
168179
----------
169180
:param *: Permission gates or clausees of permission gates
170181
"""
171-
def _run(ctx):
182+
def predicate(ctx):
172183
return validate(ctx, *clauses)
173-
return commands.check(_run)
184+
return commands.check(predicate)
174185

175186

176187

177188
class PermissionPreset:
178189
"""Permission Presets for repeatedly used permissions"""
179190

180-
Developer: PermissionRequirement = {
181-
'origin': 'guild',
182-
'type': 'wl',
183-
'query': 'is_developer',
184-
'value': None
191+
Developer: PermissionGate = {
192+
'type': 'required',
193+
'requirement': {
194+
'type': 'wl',
195+
'query': 'is_developer',
196+
'value': None
197+
}
185198
}
186-
Admin: PermissionRequirement = {
187-
'origin': 'guild',
188-
'type': 'wl',
189-
'query': 'has_permission',
190-
'value': 'administrator'
199+
Admin: PermissionGate = {
200+
'type': 'required',
201+
'requirement': PermissionRequirement(
202+
type = 'wl',
203+
query = 'has_permission',
204+
value = 'administrator'
205+
)
191206
}
192-
WithinServer: PermissionRequirement = {
193-
'origin': 'guild',
194-
'type': 'wl',
195-
'query': 'in_guild',
196-
'value': Config.GUILD_ID
207+
WithinServer: PermissionGate = {
208+
'type': 'required',
209+
'requirement': PermissionRequirement(
210+
type = 'wl',
211+
query = 'in_guild',
212+
value = Config.GUILD_ID
213+
)
197214
}
198-
Is_Member: PermissionRequirement = {
199-
'origin': 'data',
200-
'type': 'wl',
201-
'query': 'minimum_role',
202-
'value': 'Community'
215+
Is_Member: PermissionGate = {
216+
'type': 'required',
217+
'requirement': PermissionRequirement(
218+
type = 'wl',
219+
query = 'minimum_role',
220+
value = 'Community'
221+
)
203222
}
204223

205224
def __init__(self, *args, **kwargs) -> NoReturn:

0 commit comments

Comments
 (0)