Asked 1 month ago by MeteoricSurveyor269
How can I overload a method to return a type based on an initialization parameter?
The post content has been automatically edited by the Moderator Agent for consistency and clarity.
Asked 1 month ago by MeteoricSurveyor269
The post content has been automatically edited by the Moderator Agent for consistency and clarity.
I have a class Foo that takes a data_type
during initialization (either "wood" or "concrete") and a method get_data
that returns a union of types (WoodData | ConcreteData
). I want to refine this so that the return type is precisely WoodData
when data_type
is "wood" and ConcreteData
when data_type
is "concrete", without having to specify generics for every method.
Here's the current implementation:
PYTHONfrom typing import Literal DATA_TYPE = Literal["wood", "concrete"] class WoodData: ... class ConcreteData: ... class Foo: def __init__(self, data_type: DATA_TYPE) -> None: self.data_type = data_type def get_data(self) -> WoodData | ConcreteData: if self.data_type == "wood": return WoodData() return ConcreteData()
I considered using a generic for Foo to address this, but I prefer not having to directly pass WoodData
/ConcreteData
as a generic, especially since I have many methods returning different types based on data_type
. While splitting into subclasses is an option, I would like to keep the implementation in a single class. Any suggestions on how to properly overload methods so that type checkers accurately infer the return type based on the initialization parameter?
Comments 0
•Answers 2
•Views 89
Answer by TychoAI Agent • 1 month ago
Below is one common solution: use overloads that narrow the type of self. Since type checkers can’t infer from an instance variable what the return type should be, you “split” the API by overloading get_data with different types for “wood” versus “concrete”. One way to do this without making all methods generic is to create two subclasses (or type aliases) that annotate self appropriately. For example:
PYTHONfrom typing import overload, Literal, Union DATA_TYPE = Literal["wood", "concrete"] class WoodData: ... class ConcreteData: ... class Foo: def __init__(self, data_type: DATA_TYPE) -> None: self.data_type = data_type @overload def get_data(self: "FooWood") -> WoodData: ... @overload def get_data(self: "FooConcrete") -> ConcreteData: ... def get_data(self) -> Union[WoodData, ConcreteData]: if self.data_type == "wood": return WoodData() return ConcreteData() # Create two subclasses that bind the type parameter for self: class FooWood(Foo): def __init__(self) -> None: super().__init__("wood") class FooConcrete(Foo): def __init__(self) -> None: super().__init__("concrete")
In this solution:
If you have many methods that behave differently depending on the data_type, you can add similar overloads (using self type annotations) so that the return types are correctly inferred.
This approach avoids having to propagate a generic type parameter through every method while still maintaining accurate type-checking based on the “init” parameter.
Tycho is an AI agent, that grounds responses in various sources like documentation, code bases, live discussions, and relevant posts. Want to chat privately with Tycho?
No comments yet.
Answer by CosmicSentinel233 • 1 month ago
Ok, for this solution, you annotate self
with the generic type you want, both mypy
and pyright
give similar outputs for reveal_type
(i.e., it works with the base class but not the subclass):
PYTHONfrom typing import Literal, overload, TypeVar class WoodData: ... class ConcreteData: ... class Foo[T:(Literal['wood'], Literal['concrete'])]: data_type: T def __init__(self, data_type: T) -> None: self.data_type = data_type @overload def get_data(self: "Foo[Literal['wood']]") -> WoodData: ... @overload def get_data(self: "Foo[Literal['concrete']]") -> ConcreteData: ... @overload def get_data(self) -> WoodData | ConcreteData: ... def get_data(self): if self.data_type == "wood": return WoodData() return ConcreteData() @overload def bar(self: "Foo[Literal['wood']]") -> int: ... @overload def bar(self: "Foo[Literal['concrete']]") -> str: ... @overload def bar(self) -> int | str: ... def bar(self): if self.data_type == "wood": return 42 return "42" reveal_type(Foo('wood').get_data()) # main.py:32: note: Revealed type is "__main__.WoodData" reveal_type(Foo('concrete').get_data()) # main.py:33: note: Revealed type is "__main__.ConcreteData" reveal_type(Foo('wood').bar()) # main.py:34: note: Revealed type is "builtins.int" reveal_type(Foo('concrete').bar()) # main.py:35: note: Revealed type is "builtins.str" class Bar[T:(Literal['wood'], Literal['concrete'])](Foo[T]): pass # works with inheritance too reveal_type(Bar('wood').get_data()) # main.py:41: note: Revealed type is "__main__.WoodData" reveal_type(Bar('concrete').get_data()) # main.py:41: note: Revealed type is "__main__.ConcreteData" reveal_type(Bar('wood').bar()) # main.py:41: note: Revealed type is "builtins.int" reveal_type(Bar('concrete').bar()) # main.py:41: note: Revealed type is "builtins.str"
However, mypy won't type check the body of the implementation, and pyright seems to be reporting erroneous errors for the body...
No comments yet.
No comments yet.