Skip to content

Commit 29c357c

Browse files
committed
[Executorch] Make module constructors uniform across
Pull Request resolved: #15729 Existing constructors dont compose well such that if you want data loader or data files constructor then you cannot get to override memory allocator. Fix that. ghstack-source-id: 324721909 @exported-using-ghexport Differential Revision: [D86120037](https://our.internmc.facebook.com/intern/diff/D86120037/)
1 parent b49ba17 commit 29c357c

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

extension/module/module.cpp

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,17 @@ runtime::Result<std::unique_ptr<runtime::DataLoader>> make_data_loader(
7878
Module::Module(
7979
const std::string& file_path,
8080
const LoadMode load_mode,
81+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
82+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
8183
std::unique_ptr<runtime::EventTracer> event_tracer)
8284
: file_path_(file_path),
8385
load_mode_(load_mode),
84-
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
85-
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
86+
memory_allocator_(
87+
memory_allocator ? std::move(memory_allocator)
88+
: std::make_unique<MallocMemoryAllocator>()),
89+
temp_allocator_(
90+
temp_allocator ? std::move(temp_allocator)
91+
: std::make_unique<MallocMemoryAllocator>()),
8692
event_tracer_(std::move(event_tracer)) {
8793
runtime::runtime_init();
8894
}
@@ -91,11 +97,17 @@ Module::Module(
9197
const std::string& file_path,
9298
const std::string& data_map_path,
9399
const LoadMode load_mode,
100+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
101+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
94102
std::unique_ptr<runtime::EventTracer> event_tracer)
95103
: file_path_(file_path),
96104
load_mode_(load_mode),
97-
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
98-
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
105+
memory_allocator_(
106+
memory_allocator ? std::move(memory_allocator)
107+
: std::make_unique<MallocMemoryAllocator>()),
108+
temp_allocator_(
109+
temp_allocator ? std::move(temp_allocator)
110+
: std::make_unique<MallocMemoryAllocator>()),
99111
event_tracer_(std::move(event_tracer)) {
100112
if (!data_map_path.empty()) {
101113
data_files_.push_back(data_map_path);
@@ -107,12 +119,18 @@ Module::Module(
107119
const std::string& file_path,
108120
std::vector<std::string> data_files,
109121
const LoadMode load_mode,
122+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
123+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
110124
std::unique_ptr<runtime::EventTracer> event_tracer)
111125
: file_path_(file_path),
112126
data_files_(std::move(data_files)),
113127
load_mode_(load_mode),
114-
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
115-
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
128+
memory_allocator_(
129+
memory_allocator ? std::move(memory_allocator)
130+
: std::make_unique<MallocMemoryAllocator>()),
131+
temp_allocator_(
132+
temp_allocator ? std::move(temp_allocator)
133+
: std::make_unique<MallocMemoryAllocator>()),
116134
event_tracer_(std::move(event_tracer)) {
117135
runtime::runtime_init();
118136
}

extension/module/module.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ class Module {
6363
explicit Module(
6464
const std::string& file_path,
6565
const LoadMode load_mode = LoadMode::File,
66+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
67+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
6668
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
6769

6870
/**
@@ -78,6 +80,8 @@ class Module {
7880
const std::string& file_path,
7981
const std::string& data_map_path,
8082
const LoadMode load_mode = LoadMode::File,
83+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
84+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
8185
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
8286

8387
/**
@@ -93,6 +97,8 @@ class Module {
9397
const std::string& file_path,
9498
std::vector<std::string> data_files,
9599
const LoadMode load_mode = LoadMode::File,
100+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
101+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
96102
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
97103

98104
/**

0 commit comments

Comments
 (0)