#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys
import select
import struct
import socketserver as SocketServer
import threading
from typing import Callable, Dict, Generic, Tuple, Type, TYPE_CHECKING, TypeVar, Union

from pyspark.serializers import read_int, CPickleSerializer
from pyspark.errors import PySparkRuntimeError

if TYPE_CHECKING:
    from pyspark._typing import SupportsIAdd  # noqa: F401
    import socketserver.BaseRequestHandler  # type: ignore[import]


__all__ = ["Accumulator", "AccumulatorParam"]

T = TypeVar("T")
U = TypeVar("U", bound="SupportsIAdd")

pickleSer = CPickleSerializer()

# Holds accumulators registered on the current machine, keyed by ID. This is then used to send
# the local accumulator updates back to the driver program at the end of a task.
_accumulatorRegistry: Dict[int, "Accumulator"] = {}


def _deserialize_accumulator(
    aid: int, zero_value: T, accum_param: "AccumulatorParam[T]"
) -> "Accumulator[T]":
    from pyspark.accumulators import _accumulatorRegistry

    # If this certain accumulator was deserialized, don't overwrite it.
    if aid in _accumulatorRegistry:
        return _accumulatorRegistry[aid]
    else:
        accum = Accumulator(aid, zero_value, accum_param)
        accum._deserialized = True
        _accumulatorRegistry[aid] = accum
        return accum

class Accumulator(Generic[T]):
    """
    A shared variable that can be accumulated, i.e., has a commutative and associative "add"
    operation. Worker tasks on a Spark cluster can add values to an Accumulator with the `+=`
    operator, but only the driver program is allowed to access its value, using `value`.
    Updates from the workers get propagated automatically to the driver program.

    While :class:`SparkContext` supports accumulators for primitive data types like :class:`int`
    and :class:`float`, users can also define accumulators for custom types by providing a custom
    :py:class:`AccumulatorParam` object. Refer to its doctest for an example.

    Examples
    --------
    >>> a = sc.accumulator(1)
    >>> a.value
    1
    >>> a.value = 2
    >>> a.value
    2
    >>> a += 5
    >>> a.value
    7
    >>> sc.accumulator(1.0).value
    1.0
    >>> sc.accumulator(1j).value
    1j
    >>> rdd = sc.parallelize([1,2,3])
    >>> def f(x):
    ...     global a
    ...     a += x
    ...
    >>> rdd.foreach(f)
    >>> a.value
    13
    >>> b = sc.accumulator(0)
    >>> def g(x):
    ...     b.add(x)
    ...
    >>> rdd.foreach(g)
    >>> b.value
    6

    >>> rdd.map(lambda x: a.value).collect()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    Py4JJavaError: ...

    >>> def h(x):
    ...     global a
    ...     a.value = 7
    ...
    >>> rdd.foreach(h)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    Py4JJavaError: ...

    >>> sc.accumulator([1.0, 2.0, 3.0])  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    TypeError: ...
    """

    def __init__(self, aid: int, value: T, accum_param: "AccumulatorParam[T]"):
        """Create a new Accumulator with a given initial value and AccumulatorParam object"""
        from pyspark.accumulators import _accumulatorRegistry

        self.aid = aid
        self.accum_param = accum_param
        self._value = value
        self._deserialized = False
        _accumulatorRegistry[aid] = self

    def __reduce__(
        self,
    ) -> Tuple[
        Callable[[int, T, "AccumulatorParam[T]"], "Accumulator[T]"],
        Tuple[int, T, "AccumulatorParam[T]"],
    ]:
        """Custom serialization; saves the zero value from our AccumulatorParam"""
        param = self.accum_param
        return (_deserialize_accumulator, (self.aid, param.zero(self._value), param))

    @property
    def value(self) -> T:
        """Get the accumulator's value; only usable in driver program"""
        if self._deserialized:
            raise PySparkRuntimeError(
                error_class="VALUE_NOT_ACCESSIBLE",
                message_parameters={
                    "value": "Accumulator.value",
                },
            )
        return self._value

    @value.setter
    def value(self, value: T) -> None:
        """Sets the accumulator's value; only usable in driver program"""
        if self._deserialized:
            raise PySparkRuntimeError(
                error_class="VALUE_NOT_ACCESSIBLE",
                message_parameters={
                    "value": "Accumulator.value",
                },
            )
        self._value = value
    def add(self, term: T) -> None:
        """Adds a term to this accumulator's value"""
        self._value = self.accum_param.addInPlace(self._value, term)
    def __iadd__(self, term: T) -> "Accumulator[T]":
        """The += operator; adds a term to this accumulator's value"""
        self.add(term)
        return self

    def __str__(self) -> str:
        return str(self._value)

    def __repr__(self) -> str:
        return "Accumulator<id=%i, value=%s>" % (self.aid, self._value)

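
# Illustrative sketch (not part of the module): how __reduce__ and
# _deserialize_accumulator cooperate. Pickling an Accumulator ships the zero
# value from its AccumulatorParam rather than the current value, so a copy
# deserialized in a fresh worker process starts from zero, is marked
# _deserialized, and only collects local updates. The function name and the
# use of INT_ACCUMULATOR_PARAM (defined further down) are assumptions made
# only for this example.
def _example_accumulator_round_trip() -> None:
    import pickle

    acc = Accumulator(0, 10, INT_ACCUMULATOR_PARAM)
    # In this process the registry already holds aid 0, so unpickling simply
    # returns the existing driver-side instance.
    assert pickle.loads(pickle.dumps(acc)) is acc
    # In a worker process with an empty _accumulatorRegistry, the same bytes
    # would instead build a new Accumulator with value 0 whose `value`
    # property raises PySparkRuntimeError until the update is sent back to
    # the driver.
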
class AccumulatorParam(Generic[T]):
    """
    Helper object that defines how to accumulate values of a given type.

    Examples
    --------
    >>> from pyspark.accumulators import AccumulatorParam
    >>> class VectorAccumulatorParam(AccumulatorParam):
    ...     def zero(self, value):
    ...         return [0.0] * len(value)
    ...     def addInPlace(self, val1, val2):
    ...         for i in range(len(val1)):
    ...             val1[i] += val2[i]
    ...         return val1
    >>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam())
    >>> va.value
    [1.0, 2.0, 3.0]
    >>> def g(x):
    ...     global va
    ...     va += [x] * 3
    ...
    >>> rdd = sc.parallelize([1,2,3])
    >>> rdd.foreach(g)
    >>> va.value
    [7.0, 8.0, 9.0]
    """
    def zero(self, value: T) -> T:
        """
        Provide a "zero value" for the type, compatible in dimensions with the
        provided `value` (e.g., a zero vector)
        """
        raise NotImplementedError
    def addInPlace(self, value1: T, value2: T) -> T:
        """
        Add two values of the accumulator's data type, returning a new value;
        for efficiency, can also update `value1` in place and return it.
        """
        raise NotImplementedError

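
# Illustrative sketch (not part of the module): a custom AccumulatorParam that
# merges sets, showing the zero()/addInPlace() contract described above. The
# class name and the driver-side usage in the trailing comment are assumptions
# made only for this example.
class _ExampleSetUnionParam(AccumulatorParam[set]):
    def zero(self, value: set) -> set:
        # The zero value must be an identity for addInPlace: taking the union
        # with it leaves any set unchanged.
        return set()

    def addInPlace(self, value1: set, value2: set) -> set:
        # Updating value1 in place and returning it is allowed, and cheaper
        # than building a new set on every update.
        value1 |= value2
        return value1


# Driver-side usage would look roughly like:
#   tags = sc.accumulator(set(), _ExampleSetUnionParam())
#   rdd.foreach(lambda x: tags.add({x % 10}))
#   tags.value  # the union of all sets added by the workers
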
class AddingAccumulatorParam(AccumulatorParam[U]):
    """
    An AccumulatorParam that uses the + operator to add values. Designed for simple types
    such as integers, floats, and lists. Requires the zero value for the underlying type
    as a parameter.
    """

    def __init__(self, zero_value: U):
        self.zero_value = zero_value

    def zero(self, value: U) -> U:
        return self.zero_value

    def addInPlace(self, value1: U, value2: U) -> U:
        value1 += value2  # type: ignore[operator]
        return value1


# Singleton accumulator params for some standard types
INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0)  # type: ignore[type-var]
FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)  # type: ignore[type-var]
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)  # type: ignore[type-var]


class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
    """
    This handler will keep polling updates from the same socket until the
    server is shutdown.
    """

    def handle(self) -> None:
        from pyspark.accumulators import _accumulatorRegistry

        auth_token = self.server.auth_token  # type: ignore[attr-defined]

        def poll(func: Callable[[], bool]) -> None:
            while not self.server.server_shutdown:  # type: ignore[attr-defined]
                # Poll every 1 second for new data -- don't block in case of shutdown.
                r, _, _ = select.select([self.rfile], [], [], 1)
                if self.rfile in r and func():
                    break

        def accum_updates() -> bool:
            num_updates = read_int(self.rfile)
            for _ in range(num_updates):
                (aid, update) = pickleSer._read_with_length(self.rfile)
                _accumulatorRegistry[aid] += update
            # Write a byte in acknowledgement
            self.wfile.write(struct.pack("!b", 1))
            return False

        def authenticate_and_accum_updates() -> bool:
            received_token: Union[bytes, str] = self.rfile.read(len(auth_token))
            if isinstance(received_token, bytes):
                received_token = received_token.decode("utf-8")
            if received_token == auth_token:
                accum_updates()
                # we've authenticated, we can break out of the first loop now
                return True
            else:
                raise ValueError(
                    "The value of the provided token to the AccumulatorServer is not correct."
                )

        # first we keep polling till we've received the authentication token
        poll(authenticate_and_accum_updates)
        # now we've authenticated, don't need to check for the token anymore
        poll(accum_updates)


class AccumulatorServer(SocketServer.TCPServer):
    """
    A simple TCP server that intercepts shutdown() in order to interrupt our
    continuous polling on the handler.
    """

    server_shutdown = False

    def __init__(
        self,
        server_address: Tuple[str, int],
        RequestHandlerClass: Type["socketserver.BaseRequestHandler"],
        auth_token: str,
    ):
        SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass)
        self.auth_token = auth_token

    def shutdown(self) -> None:
        self.server_shutdown = True
        SocketServer.TCPServer.shutdown(self)
        self.server_close()


def _start_update_server(auth_token: str) -> AccumulatorServer:
    """Start a TCP server to receive accumulator updates in a daemon thread, and return it"""
    server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler, auth_token)
    thread = threading.Thread(target=server.serve_forever)
    thread.daemon = True
    thread.start()
    return server


if __name__ == "__main__":
    import doctest
    from pyspark.context import SparkContext

    globs = globals().copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
    globs["sc"] = SparkContext("local", "test")
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs["sc"].stop()
    if failure_count:
        sys.exit(-1)
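

# Illustrative sketch (not part of the module): what a client of
# AccumulatorServer is expected to send, based on _UpdateRequestHandler above.
# The framing mirrors read_int and pickleSer._read_with_length: the auth token,
# a 4-byte big-endian update count, then one length-prefixed pickled
# (aid, update) pair per update, after which the server replies with a single
# acknowledgement byte. The function name and its arguments are hypothetical;
# `updates` is assumed to be a list of (aid, update) pairs.
def _example_send_updates(host: str, port: int, auth_token: str, updates: list) -> bool:
    import pickle
    import socket

    with socket.create_connection((host, port)) as sock:
        # The handler reads exactly len(auth_token) bytes before anything else.
        sock.sendall(auth_token.encode("utf-8"))
        # Number of (aid, update) pairs, matching read_int's "!i" format.
        sock.sendall(struct.pack("!i", len(updates)))
        for aid, update in updates:
            payload = pickle.dumps((aid, update))
            # Each pair is length-prefixed, matching _read_with_length.
            sock.sendall(struct.pack("!i", len(payload)))
            sock.sendall(payload)
        # The server acknowledges a processed batch with one byte (value 1).
        return sock.recv(1) == struct.pack("!b", 1)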