dusantism-db commented on code in PR #49445:
URL: https://github.com/apache/spark/pull/49445#discussion_r1946691800


##########
sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingLocalVariableManager.scala:
##########
@@ -60,42 +60,52 @@ class SqlScriptingLocalVariableManager(context: 
SqlScriptingExecutionContext)
       initValue: Literal,
       identifier: Identifier): Unit = {
     def varDef = VariableDefinition(identifier, defaultValueSQL, initValue)
-    nameParts match {
-      // Unqualified case.
-      case Seq(name) =>
-        context.currentFrame.scopes
-          .findLast(_.variables.contains(name))
-          // Throw error if variable is not found. This shouldn't happen as 
the check is already
-          // done in SetVariableExec.
-          .orElse(throw unresolvedVariableError(nameParts, 
identifier.namespace().toIndexedSeq))
-          .map(_.variables.put(name, varDef))
+    findScopeOfVariable(nameParts)
+      .getOrElse(throw unresolvedVariableError(nameParts, 
identifier.namespace().toIndexedSeq))
+      .variables.put(nameParts.last, varDef)
+    }
+
+  override def get(nameParts: Seq[String]): Option[VariableDefinition] = {
+    findScopeOfVariable(nameParts).flatMap(_.variables.get(nameParts.last))
+  }
+
+  private def findScopeOfVariable(nameParts: Seq[String]): 
Option[SqlScriptingExecutionScope] = {
+    def isScopeOfVar(
+        nameParts: Seq[String],
+        scope: SqlScriptingExecutionScope
+    ): Boolean = nameParts match {
+      case Seq(name) => scope.variables.contains(name)
       // Qualified case.
-      case Seq(label, name) =>
-        context.currentFrame.scopes
-          .findLast(_.label == label)
-          .filter(_.variables.contains(name))
-          // Throw error if variable is not found. This shouldn't happen as 
the check is already
-          // done in SetVariableExec.
-          .orElse(throw unresolvedVariableError(nameParts, 
identifier.namespace().toIndexedSeq))
-          .map(_.variables.put(name, varDef))
+      case Seq(label, _) => scope.label == label
       case _ =>
-        throw SparkException.internalError("ScriptingVariableManager.set 
expects 1 or 2 nameParts.")
+        throw SparkException.internalError("ScriptingVariableManager expects 1 
or 2 nameParts.")
     }
-  }
 
-  override def get(nameParts: Seq[String]): Option[VariableDefinition] = 
nameParts match {
-    // Unqualified case.
-    case Seq(name) =>
-      context.currentFrame.scopes
-      .findLast(_.variables.contains(name))
-      .flatMap(_.variables.get(name))
-    // Qualified case.
-    case Seq(label, name) =>
-      context.currentFrame.scopes
-      .findLast(_.label == label)
-      .flatMap(_.variables.get(name))
-    case _ =>
-      throw SparkException.internalError("ScriptingVariableManager.get expects 
1 or 2 nameParts.")
+    // First search for variable in entire current frame.
+    val resCurrentFrame = context.currentFrame.scopes
+      .findLast(scope => isScopeOfVar(nameParts, scope))
+    if (resCurrentFrame.isDefined) {
+      return resCurrentFrame
+    }
+
+    // When searching in previous frames, for each frame we have to check only 
scopes before and
+    // including the scope where the previously checked frame is defined, as 
the frames
+    // should not access variables from scopes which are nested below its 
definition.

Review Comment:
   Done.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to