Skip to content

Commit 73a1106

Browse files
committed
[Executorch] Make module constructors uniform across
Pull Request resolved: #15729 Existing constructors dont compose well such that if you want data loader or data files constructor then you cannot get to override memory allocator. Fix that. ghstack-source-id: 327215742 @exported-using-ghexport Differential Revision: [D86120037](https://our.internmc.facebook.com/intern/diff/D86120037/)
1 parent 9b6ada3 commit 73a1106

File tree

2 files changed

+36
-12
lines changed

2 files changed

+36
-12
lines changed

extension/module/module.cpp

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,17 @@ runtime::Result<std::unique_ptr<runtime::DataLoader>> make_data_loader(
7878
Module::Module(
7979
const std::string& file_path,
8080
const LoadMode load_mode,
81-
std::unique_ptr<runtime::EventTracer> event_tracer)
81+
std::unique_ptr<runtime::EventTracer> event_tracer,
82+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
83+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator)
8284
: file_path_(file_path),
8385
load_mode_(load_mode),
84-
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
85-
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
86+
memory_allocator_(
87+
memory_allocator ? std::move(memory_allocator)
88+
: std::make_unique<MallocMemoryAllocator>()),
89+
temp_allocator_(
90+
temp_allocator ? std::move(temp_allocator)
91+
: std::make_unique<MallocMemoryAllocator>()),
8692
event_tracer_(std::move(event_tracer)) {
8793
runtime::runtime_init();
8894
}
@@ -91,11 +97,17 @@ Module::Module(
9197
const std::string& file_path,
9298
const std::string& data_map_path,
9399
const LoadMode load_mode,
94-
std::unique_ptr<runtime::EventTracer> event_tracer)
100+
std::unique_ptr<runtime::EventTracer> event_tracer,
101+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
102+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator)
95103
: file_path_(file_path),
96104
load_mode_(load_mode),
97-
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
98-
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
105+
memory_allocator_(
106+
memory_allocator ? std::move(memory_allocator)
107+
: std::make_unique<MallocMemoryAllocator>()),
108+
temp_allocator_(
109+
temp_allocator ? std::move(temp_allocator)
110+
: std::make_unique<MallocMemoryAllocator>()),
99111
event_tracer_(std::move(event_tracer)) {
100112
if (!data_map_path.empty()) {
101113
data_files_.push_back(data_map_path);
@@ -107,12 +119,18 @@ Module::Module(
107119
const std::string& file_path,
108120
std::vector<std::string> data_files,
109121
const LoadMode load_mode,
110-
std::unique_ptr<runtime::EventTracer> event_tracer)
122+
std::unique_ptr<runtime::EventTracer> event_tracer,
123+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
124+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator)
111125
: file_path_(file_path),
112126
data_files_(std::move(data_files)),
113127
load_mode_(load_mode),
114-
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
115-
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
128+
memory_allocator_(
129+
memory_allocator ? std::move(memory_allocator)
130+
: std::make_unique<MallocMemoryAllocator>()),
131+
temp_allocator_(
132+
temp_allocator ? std::move(temp_allocator)
133+
: std::make_unique<MallocMemoryAllocator>()),
116134
event_tracer_(std::move(event_tracer)) {
117135
runtime::runtime_init();
118136
}

extension/module/module.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@ class Module {
6363
explicit Module(
6464
const std::string& file_path,
6565
const LoadMode load_mode = LoadMode::File,
66-
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
66+
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
67+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
68+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr);
6769

6870
/**
6971
* Constructs an instance by loading a program from a file with specified
@@ -78,7 +80,9 @@ class Module {
7880
const std::string& file_path,
7981
const std::string& data_map_path,
8082
const LoadMode load_mode = LoadMode::File,
81-
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
83+
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
84+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
85+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr);
8286

8387
/**
8488
* Constructs an instance by loading a program from a file with specified
@@ -93,7 +97,9 @@ class Module {
9397
const std::string& file_path,
9498
std::vector<std::string> data_files,
9599
const LoadMode load_mode = LoadMode::File,
96-
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
100+
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
101+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
102+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr);
97103

98104
/**
99105
* Constructs an instance with the provided data loader and memory allocator.

0 commit comments

Comments
 (0)