summaryrefslogtreecommitdiffstats
path: root/valkey-loadmod.patch
blob: 74c016eacbedbd8ab51cf369e63e65e1ab896d65 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
From cae829f497550d175331d3d6cd0bbf4aed0965a4 Mon Sep 17 00:00:00 2001
From: Remi Collet <remi@remirepo.net>
Date: Sat, 4 Oct 2025 07:23:52 +0200
Subject: [PATCH] Fix #2678 don't add loadmodule when from config

only protect loadmodule from include files

Signed-off-by: Remi Collet <remi@remirepo.net>
---
 src/config.c |  8 ++++++--
 src/module.c | 17 ++++++++++++-----
 src/module.h |  4 ++--
 src/server.c |  2 +-
 4 files changed, 21 insertions(+), 10 deletions(-)

diff --git a/src/config.c b/src/config.c
index 93ef289e3..dc4116383 100644
--- a/src/config.c
+++ b/src/config.c
@@ -450,6 +450,8 @@ static int updateClientOutputBufferLimit(sds *args, int arg_len, const char **er
  * within conf file parsing. This is only needed to support the deprecated
  * abnormal aggregate `save T C` functionality. Remove in the future. */
 static int reading_config_file;
+/* support detecting include vs main config file */
+static int reading_include_file = 0;
 
 void loadServerConfigFromString(sds config) {
     deprecatedConfig deprecated_configs[] = {
@@ -541,7 +543,9 @@ void loadServerConfigFromString(sds config) {
 
         /* Execute config directives */
         if (!strcasecmp(argv[0], "include") && argc == 2) {
+            reading_include_file = 1;
             loadServerConfig(argv[1], 0, NULL);
+            reading_include_file = 0;
         } else if (!strcasecmp(argv[0], "rename-command") && argc == 3) {
             struct serverCommand *cmd = lookupCommandBySds(argv[1]);
 
@@ -574,7 +578,7 @@ void loadServerConfigFromString(sds config) {
                 goto loaderr;
             }
         } else if (!strcasecmp(argv[0], "loadmodule") && argc >= 2) {
-            moduleEnqueueLoadModule(argv[1], &argv[2], argc - 2);
+            moduleEnqueueLoadModule(argv[1], &argv[2], argc - 2, reading_include_file);
         } else if (strchr(argv[0], '.')) {
             if (argc < 2) {
                 err = "Module config specified without value";
@@ -1618,7 +1622,7 @@ void rewriteConfigLoadmoduleOption(struct rewriteConfigState *state) {
     while ((de = dictNext(di)) != NULL) {
         struct ValkeyModule *module = dictGetVal(de);
         line = moduleLoadQueueEntryToLoadmoduleOptionStr(module, "loadmodule");
-        rewriteConfigRewriteLine(state, "loadmodule", line, 1);
+        if (line) rewriteConfigRewriteLine(state, "loadmodule", line, 1);
     }
     dictReleaseIterator(di);
     /* Mark "loadmodule" as processed in case modules is empty. */
diff --git a/src/module.c b/src/module.c
index de5a5510e..2638268e1 100644
--- a/src/module.c
+++ b/src/module.c
@@ -84,6 +84,7 @@
 
 struct moduleLoadQueueEntry {
     sds path;
+    int from_include;
     int argc;
     robj **argv;
 };
@@ -679,7 +680,7 @@ void freeClientModuleData(client *c) {
     c->module_data = NULL;
 }
 
-void moduleEnqueueLoadModule(sds path, sds *argv, int argc) {
+void moduleEnqueueLoadModule(sds path, sds *argv, int argc, int from_include) {
     int i;
     struct moduleLoadQueueEntry *loadmod;
 
@@ -687,6 +688,7 @@ void moduleEnqueueLoadModule(sds path, sds *argv, int argc) {
     loadmod->argv = argc ? zmalloc(sizeof(robj *) * argc) : NULL;
     loadmod->path = sdsnew(path);
     loadmod->argc = argc;
+    loadmod->from_include = from_include;
     for (i = 0; i < argc; i++) {
         loadmod->argv[i] = createRawStringObject(argv[i], sdslen(argv[i]));
     }
@@ -697,6 +699,10 @@ sds moduleLoadQueueEntryToLoadmoduleOptionStr(ValkeyModule *module,
                                               const char *config_option_str) {
     sds line;
 
+    if (module->loadmod->from_include) {
+        /* no need to add as already from config */
+        return NULL;
+    }
     line = sdsnew(config_option_str);
     line = sdscatlen(line, " ", 1);
     line = sdscatsds(line, module->loadmod->path);
@@ -12629,7 +12635,7 @@ void moduleLoadFromQueue(void) {
     listRewind(server.loadmodule_queue, &li);
     while ((ln = listNext(&li))) {
         struct moduleLoadQueueEntry *loadmod = ln->value;
-        if (moduleLoad(loadmod->path, (void **)loadmod->argv, loadmod->argc, 0) == C_ERR) {
+        if (moduleLoad(loadmod->path, (void **)loadmod->argv, loadmod->argc, 0, loadmod->from_include) == C_ERR) {
             serverLog(LL_WARNING, "Can't load module from %s: server aborting", loadmod->path);
             exit(1);
         }
@@ -12818,7 +12824,7 @@ void moduleUnregisterCleanup(ValkeyModule *module) {
 
 /* Load a module and initialize it. On success C_OK is returned, otherwise
  * C_ERR is returned. */
-int moduleLoad(const char *path, void **module_argv, int module_argc, int is_loadex) {
+int moduleLoad(const char *path, void **module_argv, int module_argc, int is_loadex, int from_include) {
     int (*onload)(void *, void **, int);
     void *handle;
 
@@ -12893,6 +12899,7 @@ int moduleLoad(const char *path, void **module_argv, int module_argc, int is_loa
     ctx.module->loadmod->path = sdsnew(path);
     ctx.module->loadmod->argv = module_argc ? zmalloc(sizeof(robj *) * module_argc) : NULL;
     ctx.module->loadmod->argc = module_argc;
+    ctx.module->loadmod->from_include = from_include;
     for (int i = 0; i < module_argc; i++) {
         ctx.module->loadmod->argv[i] = module_argv[i];
         incrRefCount(ctx.module->loadmod->argv[i]);
@@ -13961,7 +13968,7 @@ void moduleCommand(client *c) {
             argv = &c->argv[3];
         }
 
-        if (moduleLoad(objectGetVal(c->argv[2]), (void **)argv, argc, 0) == C_OK)
+        if (moduleLoad(objectGetVal(c->argv[2]), (void **)argv, argc, 0, 0) == C_OK)
             addReply(c, shared.ok);
         else
             addReplyError(c, "Error loading the extension. Please check the server logs.");
@@ -13976,7 +13983,7 @@ void moduleCommand(client *c) {
         /* If this is a loadex command we want to populate server.module_configs_queue with
          * sds NAME VALUE pairs. We also want to increment argv to just after ARGS, if supplied. */
         if (parseLoadexArguments((ValkeyModuleString ***)&argv, &argc) == VALKEYMODULE_OK &&
-            moduleLoad(objectGetVal(c->argv[2]), (void **)argv, argc, 1) == C_OK)
+            moduleLoad(objectGetVal(c->argv[2]), (void **)argv, argc, 1, 0) == C_OK)
             addReply(c, shared.ok);
         else {
             dictEmpty(server.module_configs_queue, NULL);
diff --git a/src/module.h b/src/module.h
index c7ad384c6..3b12efbaf 100644
--- a/src/module.h
+++ b/src/module.h
@@ -169,7 +169,7 @@ static inline void moduleInitDigestContext(ValkeyModuleDigest *mdvar) {
     memset(mdvar->x, 0, sizeof(mdvar->x));
 }
 
-void moduleEnqueueLoadModule(sds path, sds *argv, int argc);
+void moduleEnqueueLoadModule(sds path, sds *argv, int argc, int from_include);
 sds moduleLoadQueueEntryToLoadmoduleOptionStr(ValkeyModule *module,
                                               const char *config_option_str);
 ValkeyModuleCtx *moduleAllocateContext(void);
@@ -181,7 +181,7 @@ void moduleFreeContext(ValkeyModuleCtx *ctx);
 void moduleInitModulesSystem(void);
 void moduleInitModulesSystemLast(void);
 void modulesCron(void);
-int moduleLoad(const char *path, void **argv, int argc, int is_loadex);
+int moduleLoad(const char *path, void **argv, int argc, int is_loadex, int from_include);
 int moduleUnload(sds name, const char **errmsg);
 void moduleUnloadAllModules(void);
 void moduleLoadFromQueue(void);
diff --git a/src/server.c b/src/server.c
index 881d83bad..a21d3ca9e 100644
--- a/src/server.c
+++ b/src/server.c
@@ -7630,7 +7630,7 @@ __attribute__((weak)) int main(int argc, char **argv) {
 #ifdef LUA_ENABLED
 #define LUA_LIB_STR STRINGIFY(LUA_LIB)
     if (scriptingEngineManagerFind("lua") == NULL) {
-        if (moduleLoad(LUA_LIB_STR, NULL, 0, 0) != C_OK) {
+        if (moduleLoad(LUA_LIB_STR, NULL, 0, 0, 1) != C_OK) {
             serverPanic("Lua engine initialization failed, check the server logs.");
         }
     }
-- 
2.53.0