-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
51 lines (36 loc) · 1.56 KB
/
main.py
File metadata and controls
51 lines (36 loc) · 1.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from lib.pid5 import Pid5
from lib.bigfive import BigFive
from lib.darktriad import Darktriad
import argparse
import os
from lib.implementations import OllamaImpl
from lib.implementations import OpenaiImpl
# --- CLI configuration -------------------------------------------------------
parser = argparse.ArgumentParser(description='Process the model name.')
parser.add_argument('--model', type=str, default='gpt-3.5-turbo',
                    help='the model to use')
parser.add_argument('--test', type=str, default='pid5',
                    help='the test to use')
parser.add_argument('--prompt', type=str, default=None,
                    help='the prompt to use')
# BUG FIX: argparse's `type=bool` is a trap — bool("False") is True, so any
# non-empty value (including "False") enabled the flag. `store_true` gives
# the intended on/off switch semantics; the default stays False.
parser.add_argument('--image', action='store_true',
                    help='whether to generate images for each items')
parser.add_argument('--tts', action='store_true',
                    help='whether to generate tts samples for each items')
parser.add_argument('--samples', type=int, default=220,
                    help='max number of samples')
# Default seed: 8 random bytes -> a fresh, unpredictable seed on every run.
parser.add_argument('--seed', type=int,
                    default=int.from_bytes(os.urandom(8), byteorder="big"))
args = parser.parse_args()

# Known model names: API-hosted models are served via OpenAI, everything
# else via a local Ollama instance.
valid_api_models = ["gpt-4", "gpt-3.5-turbo", "gpt-4o"]
valid_local_models = ["mistral", "dolphin-mixtral", "llama2", "llama3", "llama2-uncensored"]
if args.model not in valid_api_models + valid_local_models:
    raise ValueError('Invalid model name')
implementation = OllamaImpl() if args.model in valid_local_models else OpenaiImpl()

# Registry of available personality tests, keyed by each test class's ID.
TESTS = {
    Pid5.ID: Pid5,
    BigFive.ID: BigFive,
    Darktriad.ID: Darktriad
}
# Validate the test name up front so a typo yields a clear error instead of
# a bare KeyError (consistent with the --model validation above).
if args.test not in TESTS:
    raise ValueError('Invalid test name')
# Instantiate the selected test and run it; `answer()` drives the model.
test = TESTS[args.test](args, implementation).answer()