
Commit 16212e2

Making Text classification template similar to Image Classification (#92)
* Fix template sidebar
* Code style
1 parent 1c112fc commit 16212e2

File tree

1 file changed: +55 -52 lines changed


templates/text_classification/_sidebar.py

Lines changed: 55 additions & 52 deletions
@@ -18,63 +18,66 @@ def get_configs() -> dict:
     config["eval_epoch_length"] = None
     default_none_options(config)
 
-    st.header("Transformer")
+    with st.beta_expander("Text Classification Template Configurations", expanded=True):
+        st.info("Names in the parenthesis are variable names used in the generated code.")
 
-    st.subheader("Model Options")
-    config["model"] = st.selectbox(
-        "Model name (from transformers) to setup model, tokenize and config to train (model)",
-        options=["bert-base-uncased"],
-    )
-    config["model_dir"] = st.text_input("Cache directory to download the pretrained model (model_dir)", value="./")
-    config["tokenizer_dir"] = st.text_input("Tokenizer cache directory (tokenizer_dir)", value="./tokenizer")
-    config["num_classes"] = st.number_input(
-        "Number of target classes. Default, 1 (binary classification) (num_classes)", min_value=0, value=1
-    )
-    config["max_length"] = st.number_input(
-        "Maximum number of tokens for the inputs to the transformer model (max_length)", min_value=1, value=256
-    )
-    config["dropout"] = st.number_input(
-        "Dropout probability (dropout)", min_value=0.0, max_value=1.0, value=0.3, format="%f"
-    )
-    config["n_fc"] = st.number_input(
-        "Number of neurons in the last fully connected layer (n_fc)", min_value=1, value=768
-    )
-    st.markdown("---")
+        st.subheader("Model Options")
+        config["model"] = st.selectbox(
+            "Model name (from transformers) to setup model, tokenize and config to train (model)",
+            options=["bert-base-uncased"],
+        )
+        config["model_dir"] = st.text_input("Cache directory to download the pretrained model (model_dir)", value="./")
+        config["tokenizer_dir"] = st.text_input("Tokenizer cache directory (tokenizer_dir)", value="./tokenizer")
+        config["num_classes"] = st.number_input(
+            "Number of target classes. Default, 1 (binary classification) (num_classes)", min_value=0, value=1
+        )
+        config["max_length"] = st.number_input(
+            "Maximum number of tokens for the inputs to the transformer model (max_length)", min_value=1, value=256
+        )
+        config["dropout"] = st.number_input(
+            "Dropout probability (dropout)", min_value=0.0, max_value=1.0, value=0.3, format="%f"
+        )
+        config["n_fc"] = st.number_input(
+            "Number of neurons in the last fully connected layer (n_fc)", min_value=1, value=768
+        )
+        st.markdown("---")
 
-    st.subheader("Dataset Options")
-    config["data_dir"] = st.text_input("Dataset cache directory (data_dir)", value="./")
-    st.markdown("---")
+        st.subheader("Dataset Options")
+        config["data_dir"] = st.text_input("Dataset cache directory (data_dir)", value="./")
+        st.markdown("---")
 
-    st.subheader("DataLoader Options")
-    config["batch_size"] = st.number_input("Total batch size (batch_size)", min_value=1, value=4)
-    config["num_workers"] = st.number_input("Number of workers in the data loader (num_workers)", min_value=1, value=2)
-    st.markdown("---")
+        st.subheader("DataLoader Options")
+        config["batch_size"] = st.number_input("Total batch size (batch_size)", min_value=1, value=4)
+        config["num_workers"] = st.number_input(
+            "Number of workers in the data loader (num_workers)", min_value=1, value=2
+        )
+        st.markdown("---")
 
-    st.subheader("Optimizer Options")
-    config["learning_rate"] = st.number_input(
-        "Peak of piecewise linear learning rate scheduler", min_value=0.0, value=5e-5, format="%e"
-    )
-    config["weight_decay"] = st.number_input("Weight decay", min_value=0.0, value=0.01, format="%f")
-    st.markdown("---")
+        st.subheader("Optimizer Options")
+        config["learning_rate"] = st.number_input(
+            "Peak of piecewise linear learning rate scheduler", min_value=0.0, value=5e-5, format="%e"
+        )
+        config["weight_decay"] = st.number_input("Weight decay", min_value=0.0, value=0.01, format="%f")
+        st.markdown("---")
 
-    st.subheader("Training Options")
-    config["max_epochs"] = st.number_input("Number of epochs to train the model", min_value=1, value=3)
-    config["num_warmup_epochs"] = st.number_input(
-        "Number of warm-up epochs before learning rate decay", min_value=0, value=0
-    )
-    config["validate_every"] = st.number_input(
-        "Run model's validation every validate_every epochs", min_value=0, value=1
-    )
-    config["checkpoint_every"] = st.number_input(
-        "Store training checkpoint every checkpoint_every iterations", min_value=0, value=1000
-    )
-    config["log_every_iters"] = st.number_input(
-        "Argument to log batch loss every log_every_iters iterations. 0 to disable it", min_value=0, value=15
-    )
-    st.markdown("---")
+        st.subheader("Training Options")
+        config["max_epochs"] = st.number_input("Number of epochs to train the model", min_value=1, value=3)
+        config["num_warmup_epochs"] = st.number_input(
+            "Number of warm-up epochs before learning rate decay", min_value=0, value=0
+        )
+        config["validate_every"] = st.number_input(
+            "Run model's validation every validate_every epochs", min_value=0, value=1
+        )
+        config["checkpoint_every"] = st.number_input(
+            "Store training checkpoint every checkpoint_every iterations", min_value=0, value=1000
+        )
+        config["log_every_iters"] = st.number_input(
+            "Argument to log batch loss every log_every_iters iterations. 0 to disable it", min_value=0, value=15
+        )
+        st.markdown("---")
 
-    distributed_options(config)
-    ignite_handlers_options(config)
-    ignite_loggers_options(config)
+        distributed_options(config)
+        ignite_handlers_options(config)
+        ignite_loggers_options(config)
 
     return config
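
The substantive change is structural: the flat st.header("Transformer") layout is replaced by a single st.beta_expander block that wraps every option group, matching the image classification template's sidebar. Below is a minimal runnable sketch of that pattern; the panel title, widget labels, and config keys are illustrative rather than copied from the template, and note that st.beta_expander was Streamlit's name for this API at the time of the commit (later releases rename it to st.expander).

import streamlit as st

config = {}

# One collapsible panel groups all template options, mirroring the
# pattern this commit applies to the text classification sidebar.
with st.beta_expander("Demo Template Configurations", expanded=True):
    st.info("Names in the parenthesis are variable names used in the generated code.")

    st.subheader("Model Options")
    config["dropout"] = st.number_input(
        "Dropout probability (dropout)", min_value=0.0, max_value=1.0, value=0.3, format="%f"
    )
    st.markdown("---")

    st.subheader("Training Options")
    config["max_epochs"] = st.number_input("Number of epochs to train the model", min_value=1, value=3)

st.write(config)  # the collected configuration dict

Because expanded=True, the panel starts open, so the sidebar's default appearance is unchanged; users can now collapse the whole option list in one click.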

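As the st.info message says, the keys of the returned dict correspond to variable names substituted into the generated training script. A hypothetical sketch of the calling side, purely for orientation (render_templates is an assumed helper, not the repository's actual API):

configs = get_configs()  # dict collected from the sidebar widgets

# Hypothetical: hand the collected values to a template renderer that
# injects them into the generated code. The function name and signature
# here are assumptions for illustration only.
code = render_templates(template="text_classification", **configs)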