diff --git a/comfy_extras/nodes_glsl.py b/comfy_extras/nodes_glsl.py index 2f029a76e..19a6251f4 100644 --- a/comfy_extras/nodes_glsl.py +++ b/comfy_extras/nodes_glsl.py @@ -311,6 +311,21 @@ class GLContext: if self._vao is not None: self._glBindVertexArray(self._vao) + def compile_shader(self, source: str, shader_type: int) -> int: + """Compile a shader and return its ID.""" + gl = self._gl + + shader = gl.glCreateShader(shader_type) + gl.glShaderSource(shader, source) + gl.glCompileShader(shader) + + if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE: + error = gl.glGetShaderInfoLog(shader).decode() + gl.glDeleteShader(shader) + raise RuntimeError(f"Shader compilation failed:\n{error}") + + return shader + ########## class _GLContextGLFW(GLContext): @@ -500,25 +515,12 @@ class _GLContextOSMesa(GLContext): ############################################################ -def _compile_shader(source: str, shader_type: int) -> int: - """Compile a shader and return its ID.""" - shader = gl.glCreateShader(shader_type) - gl.glShaderSource(shader, source) - gl.glCompileShader(shader) - - if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE: - error = gl.glGetShaderInfoLog(shader).decode() - gl.glDeleteShader(shader) - raise RuntimeError(f"Shader compilation failed:\n{error}") - - return shader - - def _create_program(vertex_source: str, fragment_source: str) -> int: """Create and link a shader program.""" - vertex_shader = _compile_shader(vertex_source, gl.GL_VERTEX_SHADER) + ctx = GLContext() + vertex_shader = ctx.compile_shader(vertex_source, gl.GL_VERTEX_SHADER) try: - fragment_shader = _compile_shader(fragment_source, gl.GL_FRAGMENT_SHADER) + fragment_shader = ctx.compile_shader(fragment_source, gl.GL_FRAGMENT_SHADER) except RuntimeError: gl.glDeleteShader(vertex_shader) raise